This commit is contained in:
Pepijn
2025-08-30 23:11:26 +02:00
parent 47bc670ad2
commit 825c0666a9
5 changed files with 716 additions and 54 deletions

View File

@@ -212,6 +212,28 @@ def train(cfg: TrainPipelineConfig):
ds_meta=dataset.meta,
episode_data_index=episode_data_index,
)
# Setup RLearN evaluation visualizations if enabled
eval_visualizer = None
eval_holdout_episodes = None
if (getattr(cfg.policy, "type", None) == "rlearn" and
getattr(cfg.policy, "enable_eval_visualizations", False)):
try:
from lerobot.policies.rlearn.eval_visualizer import RLearNEvalVisualizer, select_evaluation_episodes
logging.info("Setting up RLearN evaluation visualizations")
eval_visualizer = RLearNEvalVisualizer(policy, dataset, device=str(device))
eval_holdout_episodes = select_evaluation_episodes(
dataset,
num_episodes=getattr(cfg.policy, "eval_holdout_episodes", 9),
seed=getattr(cfg.policy, "eval_visualization_seed", 42)
)
logging.info(f"Selected {len(eval_holdout_episodes)} holdout episodes for evaluation: {eval_holdout_episodes}")
except ImportError as e:
logging.warning(f"Could not setup RLearN evaluation visualizations: {e}")
eval_visualizer = None
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
)
@@ -386,6 +408,8 @@ def train(cfg: TrainPipelineConfig):
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
is_eval_viz_step = (eval_visualizer is not None and
step % getattr(cfg.policy, "eval_visualization_freq", 1000) == 0)
if is_log_step:
logging.info(train_tracker)
@@ -437,6 +461,87 @@ def train(cfg: TrainPipelineConfig):
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
# RLearN evaluation visualizations
if is_eval_viz_step:
logging.info(f"Creating RLearN evaluation visualizations at step {step}")
try:
with torch.no_grad():
policy.eval()
# Create evaluation visualizations directory
eval_viz_dir = cfg.output_dir / "eval_visualizations"
eval_viz_dir.mkdir(parents=True, exist_ok=True)
# Create reward prediction visualization (3x3 grid)
reward_viz_path = eval_viz_dir / f"reward_predictions_step_{step:06d}.png"
reward_metrics = eval_visualizer.create_episode_grid_visualization(
episode_indices=eval_holdout_episodes,
save_path=reward_viz_path,
step=step,
max_frames=getattr(cfg.policy, "eval_max_frames", 128)
)
# Log metrics
eval_viz_metrics = {
"eval_viz/mean_voc_s": reward_metrics["mean_voc_s"],
"eval_viz/std_voc_s": reward_metrics["std_voc_s"],
"eval_viz/valid_episodes": reward_metrics["num_valid_episodes"],
"eval_viz/total_episodes": reward_metrics["total_episodes"],
"eval_viz/mean_episode_length": reward_metrics["mean_episode_length"],
}
logging.info(f"RLearN Evaluation Results at Step {step}:")
logging.info(f" Mean VOC-S: {reward_metrics['mean_voc_s']:.4f}{reward_metrics['std_voc_s']:.4f})")
logging.info(f" Valid Episodes: {reward_metrics['num_valid_episodes']}/{reward_metrics['total_episodes']}")
logging.info(f" Mean Episode Length: {reward_metrics['mean_episode_length']:.1f}")
logging.info(f" Visualizations saved to: {eval_viz_dir}")
if wandb_logger:
wandb_logger.log_dict(eval_viz_metrics, step, mode="eval_viz")
# Log the visualization image both as regular image and as artifact
try:
import wandb
# Log as regular image for immediate viewing in wandb UI
wandb_logger.wandb_run.log({
f"eval_viz/reward_predictions_step_{step}": wandb.Image(str(reward_viz_path)),
}, step=step)
# Create and upload artifact with reward prediction visualization
artifact_name = f"rlearn_reward_predictions_step_{step:06d}"
artifact = wandb.Artifact(
name=artifact_name,
type="reward_prediction_visualization",
description=f"RLearN reward prediction visualization at training step {step}",
metadata={
"step": step,
"mean_voc_s": reward_metrics["mean_voc_s"],
"std_voc_s": reward_metrics["std_voc_s"],
"valid_episodes": reward_metrics["num_valid_episodes"],
"total_episodes": reward_metrics["total_episodes"],
"mean_episode_length": reward_metrics["mean_episode_length"],
"holdout_episodes": eval_holdout_episodes,
}
)
# Add reward prediction visualization to the artifact
artifact.add_file(str(reward_viz_path), name="reward_predictions.png")
# Upload the artifact
wandb_logger.wandb_run.log_artifact(artifact)
logging.info(f"Uploaded wandb artifact: {artifact_name}")
except Exception as e:
logging.warning(f"Could not log visualization image to wandb: {e}")
policy.train() # Return to training mode
except Exception as e:
logging.error(f"Error during RLearN evaluation visualization: {e}")
# Continue training even if evaluation fails
if eval_env:
eval_env.close()
logging.info("End of training")