diff --git a/src/lerobot/rewards/robometer/compute_rabc_weights.py b/src/lerobot/rewards/robometer/compute_rabc_weights.py index 87d844340..8e2b016fb 100644 --- a/src/lerobot/rewards/robometer/compute_rabc_weights.py +++ b/src/lerobot/rewards/robometer/compute_rabc_weights.py @@ -108,10 +108,13 @@ def compute_robometer_progress( batch_size: int = 32, num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES, episodes: list[int] | None = None, + image_key: str | None = None, ) -> Path: """Run Robometer over a dataset and write per-frame progress + success.""" logging.info(f"Loading Robometer: {reward_model_path}") config = RobometerConfig(pretrained_path=reward_model_path, device=device) + if image_key is not None: + config.image_key = image_key model = RobometerRewardModel.from_pretrained(reward_model_path, config=config) model.to(device).eval() @@ -248,6 +251,9 @@ Examples: parser.add_argument( "--episodes", type=int, nargs="+", default=None, help="Process only these episode indices." ) + parser.add_argument( + "--image-key", type=str, default=None, help="Image observation key (default: from config)." + ) parser.add_argument( "--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub." ) @@ -276,6 +282,7 @@ Examples: batch_size=args.batch_size, num_subsampled_frames=args.num_subsampled_frames, episodes=args.episodes, + image_key=args.image_key, ) print(f"\nRobometer progress saved to: {output_path}")