mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
fix
This commit is contained in:
@@ -212,6 +212,28 @@ def train(cfg: TrainPipelineConfig):
|
||||
ds_meta=dataset.meta,
|
||||
episode_data_index=episode_data_index,
|
||||
)
|
||||
|
||||
# Setup RLearN evaluation visualizations if enabled
|
||||
eval_visualizer = None
|
||||
eval_holdout_episodes = None
|
||||
if (getattr(cfg.policy, "type", None) == "rlearn" and
|
||||
getattr(cfg.policy, "enable_eval_visualizations", False)):
|
||||
|
||||
try:
|
||||
from lerobot.policies.rlearn.eval_visualizer import RLearNEvalVisualizer, select_evaluation_episodes
|
||||
|
||||
logging.info("Setting up RLearN evaluation visualizations")
|
||||
eval_visualizer = RLearNEvalVisualizer(policy, dataset, device=str(device))
|
||||
eval_holdout_episodes = select_evaluation_episodes(
|
||||
dataset,
|
||||
num_episodes=getattr(cfg.policy, "eval_holdout_episodes", 9),
|
||||
seed=getattr(cfg.policy, "eval_visualization_seed", 42)
|
||||
)
|
||||
logging.info(f"Selected {len(eval_holdout_episodes)} holdout episodes for evaluation: {eval_holdout_episodes}")
|
||||
except ImportError as e:
|
||||
logging.warning(f"Could not setup RLearN evaluation visualizations: {e}")
|
||||
eval_visualizer = None
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
@@ -386,6 +408,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
is_eval_viz_step = (eval_visualizer is not None and
|
||||
step % getattr(cfg.policy, "eval_visualization_freq", 1000) == 0)
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
@@ -437,6 +461,87 @@ def train(cfg: TrainPipelineConfig):
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
# RLearN evaluation visualizations
|
||||
if is_eval_viz_step:
|
||||
logging.info(f"Creating RLearN evaluation visualizations at step {step}")
|
||||
try:
|
||||
with torch.no_grad():
|
||||
policy.eval()
|
||||
|
||||
# Create evaluation visualizations directory
|
||||
eval_viz_dir = cfg.output_dir / "eval_visualizations"
|
||||
eval_viz_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create reward prediction visualization (3x3 grid)
|
||||
reward_viz_path = eval_viz_dir / f"reward_predictions_step_{step:06d}.png"
|
||||
reward_metrics = eval_visualizer.create_episode_grid_visualization(
|
||||
episode_indices=eval_holdout_episodes,
|
||||
save_path=reward_viz_path,
|
||||
step=step,
|
||||
max_frames=getattr(cfg.policy, "eval_max_frames", 128)
|
||||
)
|
||||
|
||||
# Log metrics
|
||||
eval_viz_metrics = {
|
||||
"eval_viz/mean_voc_s": reward_metrics["mean_voc_s"],
|
||||
"eval_viz/std_voc_s": reward_metrics["std_voc_s"],
|
||||
"eval_viz/valid_episodes": reward_metrics["num_valid_episodes"],
|
||||
"eval_viz/total_episodes": reward_metrics["total_episodes"],
|
||||
"eval_viz/mean_episode_length": reward_metrics["mean_episode_length"],
|
||||
}
|
||||
|
||||
logging.info(f"RLearN Evaluation Results at Step {step}:")
|
||||
logging.info(f" Mean VOC-S: {reward_metrics['mean_voc_s']:.4f} (±{reward_metrics['std_voc_s']:.4f})")
|
||||
logging.info(f" Valid Episodes: {reward_metrics['num_valid_episodes']}/{reward_metrics['total_episodes']}")
|
||||
logging.info(f" Mean Episode Length: {reward_metrics['mean_episode_length']:.1f}")
|
||||
logging.info(f" Visualizations saved to: {eval_viz_dir}")
|
||||
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict(eval_viz_metrics, step, mode="eval_viz")
|
||||
|
||||
# Log the visualization image both as regular image and as artifact
|
||||
try:
|
||||
import wandb
|
||||
|
||||
# Log as regular image for immediate viewing in wandb UI
|
||||
wandb_logger.wandb_run.log({
|
||||
f"eval_viz/reward_predictions_step_{step}": wandb.Image(str(reward_viz_path)),
|
||||
}, step=step)
|
||||
|
||||
# Create and upload artifact with reward prediction visualization
|
||||
artifact_name = f"rlearn_reward_predictions_step_{step:06d}"
|
||||
artifact = wandb.Artifact(
|
||||
name=artifact_name,
|
||||
type="reward_prediction_visualization",
|
||||
description=f"RLearN reward prediction visualization at training step {step}",
|
||||
metadata={
|
||||
"step": step,
|
||||
"mean_voc_s": reward_metrics["mean_voc_s"],
|
||||
"std_voc_s": reward_metrics["std_voc_s"],
|
||||
"valid_episodes": reward_metrics["num_valid_episodes"],
|
||||
"total_episodes": reward_metrics["total_episodes"],
|
||||
"mean_episode_length": reward_metrics["mean_episode_length"],
|
||||
"holdout_episodes": eval_holdout_episodes,
|
||||
}
|
||||
)
|
||||
|
||||
# Add reward prediction visualization to the artifact
|
||||
artifact.add_file(str(reward_viz_path), name="reward_predictions.png")
|
||||
|
||||
# Upload the artifact
|
||||
wandb_logger.wandb_run.log_artifact(artifact)
|
||||
|
||||
logging.info(f"Uploaded wandb artifact: {artifact_name}")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not log visualization image to wandb: {e}")
|
||||
|
||||
policy.train() # Return to training mode
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during RLearN evaluation visualization: {e}")
|
||||
# Continue training even if evaluation fails
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
Reference in New Issue
Block a user