From bac4f61eae9bf465d79ebb6f03e9fba52304b1c2 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Thu, 21 May 2026 14:32:10 +0200 Subject: [PATCH 1/7] refactor: support custom progress parquet overlays (#3640) --- examples/dataset/create_progress_videos.py | 52 +++++++++++++++------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/examples/dataset/create_progress_videos.py b/examples/dataset/create_progress_videos.py index 5f98d2cea..cb85a9d3a 100644 --- a/examples/dataset/create_progress_videos.py +++ b/examples/dataset/create_progress_videos.py @@ -15,10 +15,12 @@ # limitations under the License. """ -Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes. +Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes. Downloads datasets from HuggingFace, seeks directly into the episode segment of the source video, draws a progress line on each frame, and writes the result. +The progress data is read from a parquet file that lives alongside the dataset +(configurable via ``--progress-file``). Usage: python examples/dataset/create_progress_videos.py \ @@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8 TASK_FONT_SCALE = 0.55 -def download_episode_metadata(repo_id: str, episode: int) -> Path: - """Download only the metadata and sarm_progress files for a dataset. +def download_episode_metadata( + repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet" +) -> Path: + """Download only the metadata and per-frame progress file for a dataset. Args: repo_id: HuggingFace dataset repository ID. episode: Episode index (used for logging only; all meta is fetched). + progress_file: Filename of the per-frame progress parquet inside the + dataset repo. Returns: Local cache path for the downloaded snapshot. """ - logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode) + logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode) local_path = Path( snapshot_download( repo_id=repo_id, repo_type="dataset", - allow_patterns=["meta/**", "sarm_progress.parquet"], + allow_patterns=["meta/**", progress_file], ignore_patterns=["*.mp4"], ) ) @@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path: return video_path -def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None: - """Load sarm_progress values for an episode. +def load_progress_data( + local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet" +) -> np.ndarray | None: + """Load per-frame progress values for an episode. Args: local_path: Dataset cache root. episode: Episode index. + progress_file: Filename of the per-frame progress parquet. Returns: Sorted (N, 2) array of (frame_index, progress), or None if unavailable. """ - parquet_path = local_path / "sarm_progress.parquet" + parquet_path = local_path / progress_file if not parquet_path.exists(): - logging.warning("sarm_progress.parquet not found") + logging.warning("%s not found", progress_file) return None df = pd.read_parquet(parquet_path) - logging.info(" sarm_progress.parquet columns: %s", list(df.columns)) + logging.info(" %s columns: %s", progress_file, list(df.columns)) episode_df = df[df["episode_index"] == episode].copy() if episode_df.empty: - logging.warning("No sarm_progress rows for episode %d", episode) + logging.warning("No progress rows for episode %d in %s", episode, progress_file) return None episode_df = episode_df.sort_values("frame_index") @@ -576,6 +585,7 @@ def process_dataset( camera_key: str | None, output_dir: Path, create_gif: bool = False, + progress_file: str = "sarm_progress.parquet", ) -> Path | None: """Full pipeline: download, extract metadata, composite progress, write output. @@ -585,6 +595,8 @@ def process_dataset( camera_key: Camera key to use, or None for auto-selection. output_dir: Directory to write output files. create_gif: If True, also generate a GIF from the MP4. + progress_file: Filename of the per-frame progress parquet inside the + dataset repo. Returns: Path to the final output file, or None on failure. @@ -592,7 +604,7 @@ def process_dataset( safe_name = repo_id.replace("/", "_") logging.info("Processing: %s | episode %d", repo_id, episode) - local_path = download_episode_metadata(repo_id, episode) + local_path = download_episode_metadata(repo_id, episode, progress_file) logging.info(" Local cache: %s", local_path) episode_meta = load_episode_meta(local_path, episode, camera_key) @@ -600,9 +612,9 @@ def process_dataset( video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"]) - progress_data = load_progress_data(local_path, episode) + progress_data = load_progress_data(local_path, episode, progress_file) if progress_data is None: - logging.error("Could not load sarm_progress data. Skipping overlay.") + logging.error("Could not load progress data from %s. Skipping overlay.", progress_file) return None logging.info(" Progress frames: %d", len(progress_data)) @@ -627,7 +639,7 @@ def process_dataset( def main() -> None: parser = argparse.ArgumentParser( - description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes." + description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes." ) parser.add_argument( "--repo-id", @@ -658,6 +670,15 @@ def main() -> None: action="store_true", help="Also generate a GIF from the MP4 output.", ) + parser.add_argument( + "--progress-file", + type=str, + default="sarm_progress.parquet", + help=( + "Filename of the per-frame progress parquet inside the dataset repo " + "(default: 'sarm_progress.parquet')." + ), + ) args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") @@ -670,6 +691,7 @@ def main() -> None: camera_key=args.camera_key, output_dir=args.output_dir, create_gif=args.gif, + progress_file=args.progress_file, ) if result: From c0a2e9814df2e78c8c1201ac86a9806e8bf51b8e Mon Sep 17 00:00:00 2001 From: Nikodem Bartnik <39432165+NikodemBartnik@users.noreply.github.com> Date: Thu, 21 May 2026 22:14:07 +0200 Subject: [PATCH 2/7] fix examples (#3623) - Fixed broken API examples in Lerobot Imitation Learning Documentation - Teleoperation with cameras improved by adding a fixed frequency in the loop (without it the cameras feed gets very slow) - Wrapped record example script in main() to avoid problems on Mac - Previously teleoperation example was using SO-ARM and teleoperation with cameras was using Koch. I changed it to use SO-ARM in all of the examples. - Added section on how to train with HF Jobs - CLI and Python examples - Replaced lerobot-record with lerobot-rollout in policies examples --- docs/source/act.mdx | 16 +- docs/source/groot.mdx | 10 +- docs/source/il_robots.mdx | 317 +++++++++++++++++++++++++------------- docs/source/smolvla.mdx | 16 +- 4 files changed, 228 insertions(+), 131 deletions(-) diff --git a/docs/source/act.mdx b/docs/source/act.mdx index 8e91edcf9..f64246d7a 100644 --- a/docs/source/act.mdx +++ b/docs/source/act.mdx @@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes: ```bash -lerobot-record \ - --robot.type=so100_follower \ +lerobot-rollout \ + --strategy.type=base \ + --policy.path=${HF_USER}/act_policy \ + --robot.type=so101_follower \ --robot.port=/dev/ttyACM0 \ - --robot.id=my_robot \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --display_data=true \ - --dataset.repo_id=${HF_USER}/eval_act_your_dataset \ - --dataset.num_episodes=10 \ - --dataset.single_task="Your task description" \ - --dataset.streaming_encoding=true \ - --dataset.encoder_threads=2 \ - # --dataset.camera_encoder.vcodec=auto \ - --policy.path=${HF_USER}/act_policy + --task="Your task description" \ # can be skipped for ACT + --duration=60 ``` diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index d69d10a57..a10b5e369 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive ### Evaluate in your hardware setup -Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example: +Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example: ```bash -lerobot-record \ +lerobot-rollout\ + --strategy.type=sentry \ + --strategy.upload_every_n_episodes=5 \ --robot.type=bi_so_follower \ --robot.left_arm_port=/dev/ttyACM1 \ --robot.right_arm_port=/dev/ttyACM0 \ @@ -119,14 +121,12 @@ lerobot-record \ }' \ --display_data=true \ --dataset.repo_id=/eval_groot-bimanual \ - --dataset.num_episodes=10 \ --dataset.single_task="Grab and handover the red cube to the other arm" \ --dataset.streaming_encoding=true \ --dataset.encoder_threads=2 \ # --dataset.camera_encoder.vcodec=auto \ --policy.path=/groot-bimanual \ # your trained model - --dataset.episode_time_s=30 \ - --dataset.reset_time_s=10 + --duration=600 ``` ## License diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 07789225a..dc2e02737 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig robot_config = SO101FollowerConfig( - port="/dev/tty.usbmodem58760431541", - id="my_red_robot_arm", + port="/dev/tty.usbmodem5AB90687491", + id="my_follower_arm", ) teleop_config = SO101LeaderConfig( - port="/dev/tty.usbmodem58760431551", - id="my_blue_leader_arm", + port="/dev/tty.usbmodem5AB90689011", + id="my_leader_arm", ) robot = SO101Follower(robot_config) @@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam ```bash lerobot-teleoperate \ - --robot.type=koch_follower \ - --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.id=my_awesome_follower_arm \ - --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ - --teleop.type=koch_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ - --teleop.id=my_awesome_leader_arm \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem5AB90687491 \ + --robot.id=my_follower_arm \ + --robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem5AB90689011 \ + --teleop.id=my_leader_arm \ --display_data=true ``` @@ -122,34 +122,48 @@ lerobot-teleoperate \ ```python +import time +from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig +from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig -from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun -camera_config = { - "front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30) -} - -robot_config = KochFollowerConfig( - port="/dev/tty.usbmodem585A0076841", - id="my_red_robot_arm", - cameras=camera_config +robot_config = SO101FollowerConfig( + port="/dev/tty.usbmodem5AB90687491", + id="my_follower_arm", + cameras={ + "wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30), + "top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30) + } ) -teleop_config = KochLeaderConfig( - port="/dev/tty.usbmodem58760431551", - id="my_blue_leader_arm", +teleop_config = SO101LeaderConfig( + port="/dev/tty.usbmodem5AB90689011", + id="my_leader_arm", ) -robot = KochFollower(robot_config) -teleop_device = KochLeader(teleop_config) +init_rerun(session_name="teleoperation") + +robot = SO101Follower(robot_config) +teleop_device = SO101Leader(teleop_config) robot.connect() teleop_device.connect() +TARGET_HZ = 30 +TIME_PER_FRAME = 1.0 / TARGET_HZ + while True: + start_time = time.perf_counter() + observation = robot.get_observation() action = teleop_device.get_action() robot.send_action(action) + log_rerun_data(observation=observation, action=action) + + elapsed_time = time.perf_counter() - start_time + sleep_time = TIME_PER_FRAME - elapsed_time + if sleep_time > 0: + time.sleep(sleep_time) ``` @@ -202,10 +216,11 @@ lerobot-record \ ```python from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.datasets import LeRobotDataset +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.feature_utils import hw_to_dataset_features -from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig -from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig +from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig +from lerobot.teleoperators.so_leader.so_leader import SO101Leader from lerobot.common.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60 RESET_TIME_SEC = 10 TASK_DESCRIPTION = "My task description" -# Create robot configuration -robot_config = SO100FollowerConfig( - id="my_awesome_follower_arm", - cameras={ - "front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error. - }, - port="/dev/tty.usbmodem58760434471", -) - -teleop_config = SO100LeaderConfig( - id="my_awesome_leader_arm", - port="/dev/tty.usbmodem585A0077581", -) - -# Initialize the robot and teleoperator -robot = SO100Follower(robot_config) -teleop = SO100Leader(teleop_config) - -# Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") -dataset_features = {**action_features, **obs_features} - -# Create the dataset -dataset = LeRobotDataset.create( - repo_id="/", - fps=FPS, - features=dataset_features, - robot_type=robot.name, - use_videos=True, - image_writer_threads=4, -) - -# Initialize the keyboard listener and rerun visualization -_, events = init_keyboard_listener() -init_rerun(session_name="recording") - -# Connect the robot and teleoperator -robot.connect() -teleop.connect() - -# Create the required processors -teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() - -episode_idx = 0 -while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - - record_loop( - robot=robot, - events=events, - fps=FPS, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - teleop=teleop, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, +def main(): + # Create robot configuration + robot_config = SO101FollowerConfig( + port="/dev/tty.usbmodem5AB90687491", + id="my_follower_arm", + cameras={ + "wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30), + "top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30) + } ) - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + teleop_config = SO101LeaderConfig( + port="/dev/tty.usbmodem5AB90689011", + id="my_leader_arm", + ) + + # Initialize the robot and teleoperator + robot = SO101Follower(robot_config) + teleop = SO101Leader(teleop_config) + + # Configure the dataset features + action_features = hw_to_dataset_features(robot.action_features, "action") + obs_features = hw_to_dataset_features(robot.observation_features, "observation") + dataset_features = {**action_features, **obs_features} + + # Create the dataset + dataset = LeRobotDataset.create( + repo_id="/", + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, + ) + + # Initialize the keyboard listener and rerun visualization + _, events = init_keyboard_listener() + init_rerun(session_name="recording") + + # Connect the robot and teleoperator + robot.connect() + teleop.connect() + + # Create the required processors + teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + record_loop( robot=robot, events=events, @@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, teleop=teleop, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + teleop=teleop, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + ) - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue -# Clean up -log_say("Stop recording") -robot.disconnect() -teleop.disconnect() -dataset.push_to_hub() + dataset.save_episode() + episode_idx += 1 + + # finalize dataset + log_say("Finalizing dataset...") + dataset.finalize() + # Clean up + log_say("Stop recording") + robot.disconnect() + teleop.disconnect() + dataset.push_to_hub() + + +if __name__ == "__main__": + main() ``` @@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data ##### 2. Checkpointing and Resuming - Checkpoints are automatically created during recording. -- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset ! +- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume. - To start recording from scratch, **manually delete** the dataset directory. ##### 3. Recording Parameters @@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say episode_idx = 0 -robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm") +robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm") robot = SO100Follower(robot_config) robot.connect() @@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). +#### Train using Hugging Face Jobs + +Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs). + +To run the training use this command: + + + +```bash +hf jobs run \ + --flavor a10g-small \ + --timeout 4h \ + --secrets HF_TOKEN \ + huggingface/lerobot-gpu:latest \ + -- \ + python -m lerobot.scripts.lerobot_train \ + --dataset.repo_id=username/dataset \ + --policy.type=act \ + --steps=5000 \ + --batch_size=16 \ + --policy.device=cuda \ + --policy.repo_id=username/your_policy \ + --log_freq=100 +``` + + + + +```python +from huggingface_hub import run_job, get_token + +run_name = "act_so101_hf_jobs" +dataset_id = "username/dataset" +user_hub_id = "username" + +command_args = [ + "python", "-m", "lerobot.scripts.lerobot_train", + "--dataset.repo_id", dataset_id, + "--policy.type", "act", + "--steps", "5000", + "--batch_size", "16", + "--num_workers", "4", + "--policy.device", "cuda", + "--log_freq", "100", + "--save_freq", "1000", + "--save_checkpoint", "true", + "--wandb.enable", "false", + "--policy.repo_id", f"{user_hub_id}/{run_name}" +] + +print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...") + +job_info = run_job( + image="huggingface/lerobot-gpu:latest", + command=command_args, + flavor="a10g-small", + timeout="4h", + secrets={"HF_TOKEN": get_token()} +) + +print("\nšŸš€ Job successfully launched!") +print(f"šŸ”¹ Job ID: {job_info.id}") +print(f"šŸ”— Live UI Dashboard & Logs: {job_info.url}") +``` + + + + + +You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing. +Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`. +For longer training sessions increase the timeout. + +Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient. + +After training the model will be pushed to hub and you can use it as any other model with LeRobot. + #### Upload policy checkpoints Once training is done, upload the latest checkpoint with: diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index 6c63c5d11..e28270c9b 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i Once you are logged in, you can run inference in your setup by doing: ```bash -lerobot-record \ +lerobot-rollout \ + --strategy.type=base \ --robot.type=so101_follower \ --robot.port=/dev/ttyACM0 \ # <- Use your port --robot.id=my_blue_follower_arm \ # <- Use your robot id --robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras - --dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording - --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub - --dataset.episode_time_s=50 \ - --dataset.num_episodes=10 \ - --dataset.streaming_encoding=true \ - --dataset.encoder_threads=2 \ - # --dataset.camera_encoder.vcodec=auto \ + --task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording + # <- RTC optional, use when running on low power hardware \ + # --inference.type=rtc \ + # --inference.rtc.execution_horizon=10 \ + # --inference.rtc.max_guidance_weight=10.0 \ # <- Teleop optional if you want to teleoperate in between episodes \ # --teleop.type=so100_leader \ # --teleop.port=/dev/ttyACM0 \ # --teleop.id=my_red_leader_arm \ + # --display_data=true #optional use if you want to see the camera stream \ --policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model ``` From b74a551d38f6cf33ddde8e55b0c6f5a9b0c42e85 Mon Sep 17 00:00:00 2001 From: Haoming Song Date: Fri, 22 May 2026 16:29:34 +0800 Subject: [PATCH 3/7] fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610) * chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchamrk for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc. --- src/lerobot/policies/groot/groot_n1.py | 24 +- src/lerobot/policies/pi0/modeling_pi0.py | 56 +- src/lerobot/policies/pi05/modeling_pi05.py | 56 +- .../pi0_pi05/openpi_pytorch/__init__.py | 1 + .../policies/pi0_pi05/openpi_pytorch/gemma.py | 22 + .../pi0_pi05/openpi_pytorch/gemma_pytorch.py | 300 +++++++++++ .../pi0_pi05/openpi_pytorch/image_tools.py | 79 +++ .../pi0_pi05/openpi_pytorch/pi0_pytorch.py | 471 +++++++++++++++++ .../openpi_pytorch/preprocessing_pytorch.py | 179 +++++++ tests/policies/pi0_pi05/test_pi05_compile.py | 101 ++++ .../pi0_pi05/test_pi05_original_vs_lerobot.py | 485 ++++++------------ tests/policies/pi0_pi05/test_pi0_compile.py | 99 ++++ .../pi0_pi05/test_pi0_original_vs_lerobot.py | 479 ++++++----------- tests/policies/pi0_pi05/utils/__init__.py | 1 + .../policies/pi0_pi05/utils/openpi_parity.py | 291 +++++++++++ .../policies/pi0_pi05/utils/torch_compile.py | 207 ++++++++ tests/processor/test_pi05_processor.py | 155 ++++++ tests/processor/test_pi0_processor.py | 156 ++++++ 18 files changed, 2463 insertions(+), 699 deletions(-) create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/__init__.py create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/gemma.py create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/image_tools.py create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py create mode 100644 tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py create mode 100644 tests/policies/pi0_pi05/test_pi05_compile.py create mode 100644 tests/policies/pi0_pi05/test_pi0_compile.py create mode 100644 tests/policies/pi0_pi05/utils/__init__.py create mode 100644 tests/policies/pi0_pi05/utils/openpi_parity.py create mode 100644 tests/policies/pi0_pi05/utils/torch_compile.py create mode 100644 tests/processor/test_pi05_processor.py create mode 100644 tests/processor/test_pi0_processor.py diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 381c5fbd6..c9110301f 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -14,7 +14,7 @@ # limitations under the License. from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: + from huggingface_hub.dataclasses import strict from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel from transformers.feature_extraction_utils import BatchFeature else: + + def strict(cls): + return cls + AutoConfig = None AutoModel = None PretrainedConfig = object @@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3 # config +@strict class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" - backbone_cfg: dict - action_head_cfg: dict - action_horizon: int - action_dim: int + backbone_cfg: dict[str, Any] | None = None + action_head_cfg: dict[str, Any] | None = None + action_horizon: int = 0 + action_dim: int = 0 compute_dtype: str = "float32" - def __init__(self, **kwargs): - super().__init__(**kwargs) - for key, value in kwargs.items(): - setattr(self, key, value) + def __post_init__(self, **kwargs): + self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg + self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg + super().__post_init__(**kwargs) # real model diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 510af0796..f6f4212fb 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -15,7 +15,6 @@ # limitations under the License. import builtins -import copy import logging import math from collections import deque @@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: + from transformers.cache_utils import DynamicCache from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma @@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available: ) else: CONFIG_MAPPING = None + DynamicCache = None modeling_gemma = None PiGemmaForCausalLM = None _gated_residual = None @@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks +def clone_past_key_values(past_key_values): + """Clone the DynamicCache returned by prefix prefill for compiled denoising.""" + return DynamicCache( + tuple( + (keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values + ) + ) + + def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. @@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) # Define the complete layer computation function for gradient checkpointing -def compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert -): - models = [paligemma.model.language_model, gemma_expert.model] +def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb): query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -258,15 +265,16 @@ def compute_layer_complete( device=query_states.device, dtype=query_states.dtype, ) - cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + cos, sin = rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling + paligemma_layer = layers[0] + scaling = paligemma_layer.self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.model.language_model.layers[layer_idx].self_attn, + paligemma_layer.self_attn, query_states, key_states, value_states, @@ -274,13 +282,13 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma_layer.self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) @@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.model.language_model, self.gemma_expert.model] - num_layers = self.paligemma.config.text_config.num_hidden_layers + paligemma_layers = self.paligemma.model.language_model.layers + gemma_expert_layers = self.gemma_expert.model.layers + rotary_emb = self.paligemma.model.language_model.rotary_emb # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( @@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel( ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled - for layer_idx in range(num_layers): + for layers in zip(paligemma_layers, gemma_expert_layers, strict=True): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) else: inputs_embeds = compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) # final norm + final_norms = ( + self.paligemma.model.language_model.norm, + self.gemma_expert.model.norm, + ) + def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) + out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - past_key_values = copy.deepcopy(past_key_values) + past_key_values = clone_past_key_values(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index bdaf01f2c..aabd04c6f 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -15,7 +15,6 @@ # limitations under the License. import builtins -import copy import logging import math from collections import deque @@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: + from transformers.cache_utils import DynamicCache from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma @@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available: ) else: CONFIG_MAPPING = None + DynamicCache = None modeling_gemma = None PiGemmaForCausalLM = None _gated_residual = None @@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` ( return att_2d_masks & pad_2d_masks +def clone_past_key_values(past_key_values): + """Clone the DynamicCache returned by prefix prefill for compiled denoising.""" + return DynamicCache( + tuple( + (keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values + ) + ) + + def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. @@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) # Define the complete layer computation function for gradient checkpointing -def compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert -): - models = [paligemma.model.language_model, gemma_expert.model] +def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb): query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] @@ -255,15 +262,16 @@ def compute_layer_complete( device=query_states.device, dtype=query_states.dtype, ) - cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + cos, sin = rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling + paligemma_layer = layers[0] + scaling = paligemma_layer.self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.model.language_model.layers[layer_idx].self_attn, + paligemma_layer.self_attn, query_states, key_states, value_states, @@ -271,13 +279,13 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma_layer.self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] + layer = layers[i] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) @@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.model.language_model, self.gemma_expert.model] - num_layers = self.paligemma.config.text_config.num_hidden_layers + paligemma_layers = self.paligemma.model.language_model.layers + gemma_expert_layers = self.gemma_expert.model.layers + rotary_emb = self.paligemma.model.language_model.rotary_emb # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( @@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel( ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled - for layer_idx in range(num_layers): + for layers in zip(paligemma_layers, gemma_expert_layers, strict=True): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) else: inputs_embeds = compute_layer_complete( - layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, - paligemma=self.paligemma, - gemma_expert=self.gemma_expert, + layers=layers, + rotary_emb=rotary_emb, ) # final norm + final_norms = ( + self.paligemma.model.language_model.norm, + self.gemma_expert.model.norm, + ) + def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) + out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds @@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 - past_key_values = copy.deepcopy(past_key_values) + past_key_values = clone_past_key_values(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, diff --git a/tests/policies/pi0_pi05/openpi_pytorch/__init__.py b/tests/policies/pi0_pi05/openpi_pytorch/__init__.py new file mode 100644 index 000000000..e1cdcf3fc --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/__init__.py @@ -0,0 +1 @@ +"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests.""" diff --git a/tests/policies/pi0_pi05/openpi_pytorch/gemma.py b/tests/policies/pi0_pi05/openpi_pytorch/gemma.py new file mode 100644 index 000000000..2210f5c01 --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/gemma.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + + +@dataclass +class Config: + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + + +def get_config(variant: str) -> Config: + """Return the Gemma shape config needed by the OpenPI PyTorch model.""" + if variant == "dummy": + return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16) + if variant == "gemma_300m": + return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256) + if variant == "gemma_2b": + return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256) + raise ValueError(f"Unknown variant: {variant}") diff --git a/tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py b/tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py new file mode 100644 index 000000000..48f07cd35 --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py @@ -0,0 +1,300 @@ +from typing import Literal + +import torch +from torch import nn +from transformers.models.auto import CONFIG_MAPPING +from transformers.models.gemma import modeling_gemma + +from lerobot.policies.pi_gemma import ( + PaliGemmaForConditionalGenerationWithPiGemma, + PiGemmaForCausalLM, + _gated_residual, + layernorm_forward, +) + + +class PaliGemmaWithExpertModel(nn.Module): + def __init__( + self, + vlm_config, + action_expert_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) + self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower", + "multi_modal_projector", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + # Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size), + # so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics. + # See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a + out_dtype = image.dtype + if image.dtype != torch.float32: + image = image.to(torch.float32) + image_outputs = self.paligemma.model.get_image_features(image) + features = image_outputs.pooler_output + if features.dtype != out_dtype: + features = features.to(out_dtype) + return features + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.model.language_model.get_input_embeddings()(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.model.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.model.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Force enable gradient checkpointing if we're in training mode and the model supports it + if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"): + if not self.gemma_expert.model.gradient_checkpointing: + print("Forcing gradient checkpointing to be enabled for Gemma expert model") + self.gemma_expert.model.gradient_checkpointing = True + use_gradient_checkpointing = True + + # Debug gradient checkpointing status + if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed: + print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") + print(f"Model training mode: {self.training}") + print( + f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}" + ) + if hasattr(self.gemma_expert.model, "gradient_checkpointing"): + print( + f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}" + ) + self._debug_gc_printed = True + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.model.language_model, self.gemma_expert.model] + + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layernorm_forward( + layer.input_layernorm, hidden_states, adarms_cond[i] + ) + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + + batch_size = query_states.shape[0] + scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.model.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + + # first residual + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) + after_first_residual = out_emb.clone() + out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = _gated_residual(after_first_residual, out_emb, gate) + outputs_embeds.append(out_emb) + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond + ) + + # Old code removed - now using compute_layer_complete function above + + # final norm + # Define final norm computation function for gradient checkpointing + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values diff --git a/tests/policies/pi0_pi05/openpi_pytorch/image_tools.py b/tests/policies/pi0_pi05/openpi_pytorch/image_tools.py new file mode 100644 index 000000000..a459f7859 --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/image_tools.py @@ -0,0 +1,79 @@ +import torch +import torch.nn.functional as F # noqa: N812 + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + # Convert to channels-first for torch operations + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + if batch_size == 1 and images.shape[0] == 1: + padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added + + return padded_images diff --git a/tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py b/tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py new file mode 100644 index 000000000..77f1342d9 --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py @@ -0,0 +1,471 @@ +import copy +import logging +import math + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + +import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma +from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing +from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel + + +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + +def sample_beta(alpha, beta, bsize, device): + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +class PI0Pytorch(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pi05 = config.pi05 + + paligemma_config = _gemma.get_config(config.paligemma_variant) + action_expert_config = _gemma.get_config(config.action_expert_variant) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True] if self.pi05 else [False, False], + precision=config.dtype, + ) + + self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) + + if self.pi05: + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + else: + self.state_proj = nn.Linear(config.action_dim, action_expert_config.width) + self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) + self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + torch.set_float32_matmul_precision("high") + if config.pytorch_compile_mode is not None: + self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # The upstream OpenPI module verifies a site-package Transformers patch here. + # This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer. + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def is_gradient_checkpointing_enabled(self): + """Check if gradient checkpointing is enabled.""" + return self.gradient_checkpointing_enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) + + def _preprocess_observation(self, observation, *, train=True): + """Helper method to preprocess observation.""" + observation = _preprocessing.preprocess_observation_pytorch(observation, train=train) + return ( + list(observation.images.values()), + list(observation.image_masks.values()), + observation.tokenized_prompt, + observation.tokenized_prompt_mask, + observation.state, + ) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for PaliGemma transformer processing. + """ + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + # Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching + # OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice. + # See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834 + return lang_emb + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # Get batch size from the first dimension of the concatenated tensors + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + if not self.pi05: + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + # Embed state + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=4e-3, + max_period=4.0, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + if not self.pi05: + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + # Apply MLP layers + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) # swish == silu + return self.action_time_mlp_out(x) + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None + else: + # time MLP (for adaRMS) + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) # swish == silu + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, observation, actions, noise=None, time=None) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation( + observation, train=True + ) + + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + if ( + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + # Prepare attention masks + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + # Apply gradient checkpointing if enabled + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) + + suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out.to(dtype=torch.float32) + + # Apply gradient checkpointing to final action projection if enabled + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() + def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = observation.state.shape[0] + if noise is None: + actions_shape = (bsize, self.config.action_horizon, self.config.action_dim) + noise = self.sample_noise(actions_shape, device) + + images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation( + observation, train=False + ) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Compute image and language key value cache + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + + # Euler step - use new tensor assignment instead of in-place operation + x_t = x_t + dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + # Prepare attention masks + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + past_key_values = copy.deepcopy(past_key_values) + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out.to(dtype=torch.float32) + return self.action_out_proj(suffix_out) diff --git a/tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py b/tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py new file mode 100644 index 000000000..40df8d947 --- /dev/null +++ b/tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py @@ -0,0 +1,179 @@ +import logging +from collections.abc import Sequence + +import torch + +from tests.policies.pi0_pi05.openpi_pytorch import image_tools + +logger = logging.getLogger("openpi") + +# Constants moved from model.py +IMAGE_KEYS = ( + "base_0_rgb", + "left_wrist_0_rgb", + "right_wrist_0_rgb", +) + +IMAGE_RESOLUTION = (224, 224) + + +def preprocess_observation_pytorch( + observation, + *, + train: bool = False, + image_keys: Sequence[str] = IMAGE_KEYS, + image_resolution: tuple[int, int] = IMAGE_RESOLUTION, +): + """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations. + + This function avoids complex type annotations that can cause torch.compile issues. + """ + if not set(image_keys).issubset(observation.images): + raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}") + + batch_shape = observation.state.shape[:-1] + + out_images = {} + for key in image_keys: + image = observation.images[key] + + # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats + # Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + image = image.permute(0, 2, 3, 1) + + if image.shape[1:3] != image_resolution: + logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}") + image = image_tools.resize_with_pad_torch(image, *image_resolution) + + if train: + # Convert from [-1, 1] to [0, 1] for PyTorch augmentations + image = image / 2.0 + 0.5 + + # Apply PyTorch-based augmentations + if "wrist" not in key: + # Geometric augmentations for non-wrist cameras + height, width = image.shape[1:3] + + # Random crop and resize + crop_height = int(height * 0.95) + crop_width = int(width * 0.95) + + # Random crop + max_h = height - crop_height + max_w = width - crop_width + if max_h > 0 and max_w > 0: + # Use tensor operations instead of .item() for torch.compile compatibility + start_h = torch.randint(0, max_h + 1, (1,), device=image.device) + start_w = torch.randint(0, max_w + 1, (1,), device=image.device) + image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] + + # Resize back to original size + image = torch.nn.functional.interpolate( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + size=(height, width), + mode="bilinear", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Random rotation (small angles) + # Use tensor operations instead of .item() for torch.compile compatibility + angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees + if torch.abs(angle) > 0.1: # Only rotate if angle is significant + # Convert to radians + angle_rad = angle * torch.pi / 180.0 + + # Create rotation matrix + cos_a = torch.cos(angle_rad) + sin_a = torch.sin(angle_rad) + + # Apply rotation using grid_sample + grid_x = torch.linspace(-1, 1, width, device=image.device) + grid_y = torch.linspace(-1, 1, height, device=image.device) + + # Create meshgrid + grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") + + # Expand to batch dimension + grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1) + grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1) + + # Apply rotation transformation + grid_x_rot = grid_x * cos_a - grid_y * sin_a + grid_y_rot = grid_x * sin_a + grid_y * cos_a + + # Stack and reshape for grid_sample + grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) + + image = torch.nn.functional.grid_sample( + image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + # Color augmentations for all cameras + # Random brightness + # Use tensor operations instead of .item() for torch.compile compatibility + brightness_factor = ( + 0.7 + torch.rand(1, device=image.device) * 0.6 + ) # Random factor between 0.7 and 1.3 + image = image * brightness_factor + + # Random contrast + # Use tensor operations instead of .item() for torch.compile compatibility + contrast_factor = ( + 0.6 + torch.rand(1, device=image.device) * 0.8 + ) # Random factor between 0.6 and 1.4 + mean = image.mean(dim=[1, 2, 3], keepdim=True) + image = (image - mean) * contrast_factor + mean + + # Random saturation (convert to HSV, modify S, convert back) + # For simplicity, we'll just apply a random scaling to the color channels + # Use tensor operations instead of .item() for torch.compile compatibility + saturation_factor = ( + 0.5 + torch.rand(1, device=image.device) * 1.0 + ) # Random factor between 0.5 and 1.5 + gray = image.mean(dim=-1, keepdim=True) + image = gray + (image - gray) * saturation_factor + + # Clamp values to [0, 1] + image = torch.clamp(image, 0, 1) + + # Back to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + out_images[key] = image + + # obtain mask + out_masks = {} + for key in out_images: + if key not in observation.image_masks: + # do not mask by default + out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device) + else: + out_masks[key] = observation.image_masks[key] + + # Create a simple object with the required attributes instead of using the complex Observation class + class SimpleProcessedObservation: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return SimpleProcessedObservation( + images=out_images, + image_masks=out_masks, + state=observation.state, + tokenized_prompt=observation.tokenized_prompt, + tokenized_prompt_mask=observation.tokenized_prompt_mask, + token_ar_mask=observation.token_ar_mask, + token_loss_mask=observation.token_loss_mask, + ) diff --git a/tests/policies/pi0_pi05/test_pi05_compile.py b/tests/policies/pi0_pi05/test_pi05_compile.py new file mode 100644 index 000000000..ce940e998 --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi05_compile.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + +pytest.importorskip("transformers") + +from lerobot.policies.pi05 import PI05Config # noqa: E402 +from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402 +from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402 + assert_cache_stability, + assert_compiled_output_matches_eager, + assert_explain_has_no_graph_breaks, + benchmark_runtime, + make_compile_config, + reset_compile_state, +) +from tests.utils import require_cuda # noqa: E402 + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes", +) + + +def _make_model(*, compile_model): + return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval() + + +def _make_dummy_inputs(config): + device = torch.device("cuda") + common = { + "images": [torch.randn(1, 3, *config.image_resolution, device=device)], + "img_masks": [torch.ones(1, dtype=torch.bool, device=device)], + "tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device), + "masks": torch.ones(1, 5, dtype=torch.bool, device=device), + } + forward_kwargs = { + **common, + "actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "time": torch.rand(1, device=device), + } + sample_kwargs = { + **common, + "noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "num_steps": config.num_inference_steps, + } + return forward_kwargs, sample_kwargs + + +@require_cuda +def test_pi05_torch_compile_forward_and_sample_actions(): + if not hasattr(torch, "compile"): + pytest.skip("torch.compile is not available") + if not torch._dynamo.is_dynamo_supported(): + pytest.skip("torch._dynamo is not supported on this platform") + + torch.manual_seed(0) + eager_model = _make_model(compile_model=False) + torch.manual_seed(0) + compiled_model = _make_model(compile_model=True) + forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config) + + try: + assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs) + + assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward") + assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions") + + assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward") + assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions") + + benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward") + benchmark_runtime( + eager_model.sample_actions, + compiled_model.sample_actions, + sample_kwargs, + "pi05.sample_actions", + ) + finally: + reset_compile_state() + del eager_model + del compiled_model + torch.cuda.empty_cache() diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index a965132b0..2ab5f1b94 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -14,52 +14,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation""" +"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference.""" +import gc import os -from copy import deepcopy -from typing import Any -import numpy as np import pytest import torch -# Skip if openpi or transformers is not available -pytest.importorskip("openpi") pytest.importorskip("transformers") -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", +from lerobot.configs import PreTrainedConfig # noqa: E402 +from lerobot.policies.pi05 import PI05Policy # noqa: E402 +from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 +from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 +from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402 + assert_processor_inputs_match_lerobot, + clone_batch, + deterministic_openpi_forward_preprocess, + fix_reference_state_dict, + fixed_flow_sampling, + load_openpi_reference_state_dict, + make_openpi_observation_from_raw, + openpi_model_actions_from_raw, ) -from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402 +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes", +) -# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. -from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 - -from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402 -from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyProcessorPipeline # noqa: E402 -from lerobot.types import PolicyAction # noqa: E402 - -# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 DUMMY_STATE_DIM = 32 DUMMY_ACTION_HORIZON = 50 DUMMY_MAX_TOKEN_LEN = 200 -DEVICE = "cpu" # Use CPU to avoid memory issues for testing +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +COMPILE_MODE = "default" +FORWARD_RTOL = 1e-4 +FORWARD_ATOL = 1e-4 +SAMPLE_RTOL = 1e-2 +SAMPLE_ATOL = 5e-3 DUMMY_DATASET_STATS = { - "observation.state": { + OBS_STATE: { "mean": torch.zeros(DUMMY_STATE_DIM), "std": torch.ones(DUMMY_STATE_DIM), "q01": torch.zeros(DUMMY_STATE_DIM), "q99": torch.ones(DUMMY_STATE_DIM), }, - "action": { + ACTION: { "mean": torch.zeros(DUMMY_ACTION_DIM), "std": torch.ones(DUMMY_ACTION_DIM), "q01": torch.zeros(DUMMY_ACTION_DIM), @@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = { } +@pytest.fixture(autouse=True) +def cleanup_cuda_after_test(): + yield + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + class PI05BaseOriginalConfig: action_dim: int = DUMMY_ACTION_DIM action_horizon: int = DUMMY_ACTION_HORIZON @@ -96,341 +109,163 @@ class PI05BaseOriginalConfig: precision: str = "float32" pi05: bool = True dtype: str = "float32" + pytorch_compile_mode: str | None = None -def instantiate_lerobot_pi05( - from_pretrained: bool = False, -) -> tuple[ - PI05Policy, - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - if from_pretrained: - # Load the policy first - policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True) - else: - config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32") - policy = PI05Policy(config) +def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False): + config = PreTrainedConfig.from_pretrained("lerobot/pi05_base") + config.device = str(DEVICE) + config.dtype = "float32" + config.compile_model = compile_model + config.compile_mode = COMPILE_MODE + config.gradient_checkpointing = gradient_checkpointing + policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True) policy.to(DEVICE) - policy.config.device = DEVICE - preprocessor, postprocessor = make_pi05_pre_post_processors( - config=policy.config, dataset_stats=DUMMY_DATASET_STATS - ) - return (policy, preprocessor, postprocessor) + policy.config.device = str(DEVICE) + preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS) + return policy, preprocessor -def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None): - config = PI05BaseOriginalConfig() - policy = PI0Pytorch(config) +def instantiate_original_pi05(): + policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE) - if from_pretrained: - try: - print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...") - - # Download the model from HuggingFace Hub - import safetensors.torch - from huggingface_hub import snapshot_download - - # Download the entire repository - if model_path and os.path.exists(model_path): - cache_dir = model_path - print(f"Using cached model from: {cache_dir}") - else: - cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model") - print(f"Downloaded model to: {cache_dir}") - - # Try to load safetensors format first - model_file = os.path.join(cache_dir, "model.safetensors") - if os.path.exists(model_file): - state_dict = safetensors.torch.load_file(model_file) - print(f"Loaded {len(state_dict)} parameters from safetensors") - else: - raise FileNotFoundError(f"No safetensors file found in {cache_dir}") - - # Load the state dict into the model - missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) - - if missing_keys: - print(f"Missing keys: {len(missing_keys)}") - if len(missing_keys) <= 5: - for key in missing_keys: - print(f" - {key}") - else: - for key in missing_keys[:5]: - print(f" - {key}") - print(f" ... and {len(missing_keys) - 5} more") - - if unexpected_keys: - print(f"Unexpected keys: {len(unexpected_keys)}") - if len(unexpected_keys) <= 5: - for key in unexpected_keys: - print(f" - {key}") - else: - for key in unexpected_keys[:5]: - print(f" - {key}") - print(f" ... and {len(unexpected_keys) - 5} more") - - if not missing_keys and not unexpected_keys: - print("All pretrained weights loaded successfully!") - else: - print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") - - except Exception as e: - print(f"Failed to load pretrained weights: {e}") - print(" Using randomly initialized weights...") - import traceback - - traceback.print_exc() - - policy.to(DEVICE) + # NOTE: `lerobot/pi05_base` ēš„ LeRobot loader 和 PI0 äø€ę ·ä¼šåœØ strict load 前做 key + # å…¼å®¹č½¬ę¢ļ¼Œå› ę­¤é¢„ęœŸę²”ęœ‰ missing_keys ꈖ unexpected_keys怂vendored reference åˆ™ę˜Æč£ø + # `nn.Module`ļ¼Œéœ€č¦åœØęµ‹čÆ•ä¾§č”„é½ checkpoint äøŽęØ”å—å‘½åä¹‹é—“ēš„ęœ€å°å·®å¼‚ć€‚ + # NOTE: `lm_head.weight` 是 PaliGemma tied embedding ēš„äæå­˜åļ¼›LeRobot ēš„ + # from_pretrained ä¼šęŠŠå®ƒę˜ å°„åˆ°å†…éƒØ `embed_tokens.weight`ļ¼Œč€Œ reference ęØ”åž‹ę²”ęœ‰čæ™å±‚ + # loaderļ¼Œę‰€ä»„čæ™é‡Œę‰‹åŠØå¤ē”ØåŒäø€ä»½ tensorļ¼Œéæå…ęŠŠęƒé‡åˆ«åå·®å¼‚čÆÆåˆ¤ęˆęØ”åž‹å·®å¼‚ć€‚ + state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base")) + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + assert missing_keys == [] + assert unexpected_keys == [] return policy def create_dummy_data(): - batch_size = 2 # Reduce batch size for testing - device = DEVICE - - # Use the exact same prompt for both implementations + batch_size = 2 prompt = "Pick up the red block and place it in the bin" - - batch = { - "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), - "action": torch.randn( - batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + return { + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE), + ACTION: torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE ), - # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) "observation.images.base_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), "observation.images.left_wrist_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), "observation.images.right_wrist_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), - # Add the task prompt for LeRobot - provide as list with single element to trigger expansion "task": [prompt for _ in range(batch_size)], } - return batch -def extract_lerobot_processed_inputs(lerobot_pi0, batch): - """Extract the exact same processed inputs that LeRobot uses internally.""" - # Get the tokenized language from LeRobot's internal method - lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) - - # Get the preprocessed images from LeRobot's internal method - images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) - - # Create dummy token_ar_mask and token_loss_mask for original implementation - token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) - token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) - - return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask +def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor): + torch.manual_seed(0) + raw_batch = create_dummy_data() + lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch)) + openpi_observation = make_openpi_observation_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + max_token_len=DUMMY_MAX_TOKEN_LEN, + dataset_stats=DUMMY_DATASET_STATS, + pi05=True, + ) + openpi_actions = openpi_model_actions_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + dataset_stats=DUMMY_DATASET_STATS, + pi05=True, + ) + assert_processor_inputs_match_lerobot( + lerobot_pi05, + lerobot_batch, + openpi_observation, + compare_state=False, + ) + batch_size = raw_batch[OBS_STATE].shape[0] + noise = torch.randn( + batch_size, + DUMMY_ACTION_HORIZON, + DUMMY_ACTION_DIM, + dtype=torch.float32, + device=DEVICE, + ) + time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE) + return lerobot_batch, openpi_observation, openpi_actions, noise, time -class PI05Observation: - """Observation class that matches the original OpenPI format.""" - - def __init__( - self, - state, - images, - image_masks, - tokenized_prompt, - tokenized_prompt_mask, - token_ar_mask, - token_loss_mask, - ): - self.state = state - self.images = images - self.image_masks = image_masks - self.tokenized_prompt = tokenized_prompt - self.tokenized_prompt_mask = tokenized_prompt_mask - self.token_ar_mask = token_ar_mask - self.token_loss_mask = token_loss_mask - - -def create_original_observation_with_openpi_preprocessing(batch): - """Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer.""" - batch_size = batch["observation.state"].shape[0] - device = batch["observation.state"].device - - # Create tokenizer for OpenPI (same as LeRobot uses) - tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - - # Get task description (pi05 processor handles all text formatting) - tasks = batch.get("task", ["Pick up the object"] * batch_size) - if isinstance(tasks, str): - tasks = [tasks] * batch_size - elif len(tasks) == 1: - tasks = tasks * batch_size - - # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep) - state = batch["observation.state"] - state = deepcopy(state) - - # Prepare state (pad to max_state_dim) - from lerobot.policies.pi05.modeling_pi05 import pad_vector - - state = pad_vector(state, DUMMY_STATE_DIM) - - # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) - # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) - state_np = state.cpu().numpy() - discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - - # Create pi05-formatted prompts that include state information - full_prompts = [] - for i, task in enumerate(tasks): - cleaned_text = task.strip().replace("_", " ").replace("\n", " ") - state_str = " ".join(map(str, discretized_states[i])) - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " - full_prompts.append(full_prompt) - - # Tokenize with max_length padding to match OpenPI's expected format - tokenized = tokenizer( - full_prompts, - padding="max_length", - padding_side="right", - truncation=True, - max_length=DUMMY_MAX_TOKEN_LEN, - return_tensors="pt", +def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False): + lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05( + compile_model=compile_model, + gradient_checkpointing=gradient_checkpointing, + ) + original_pi05 = instantiate_original_pi05() + lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs( + lerobot_pi05, + lerobot_preprocessor, ) - lang_tokens = tokenized["input_ids"].to(device) - lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + if gradient_checkpointing: + lerobot_pi05.train() + else: + lerobot_pi05.eval() + original_pi05.eval() - # Create dummy token_ar_mask and token_loss_mask for OpenPI - token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) - token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time): + lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none") + with deterministic_openpi_forward_preprocess(original_pi05): + openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time) + openpi_loss = openpi_losses.mean(dim=(1, 2)) - # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) - image_dict = { - "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, - "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, - "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, - } + torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL) - # Create image masks (all ones for real images) - image_masks_dict = {} - for key in image_dict: - image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) - # Create raw observation object (before preprocessing) - raw_observation = PI05Observation( - state=batch["observation.state"], - images=image_dict, - image_masks=image_masks_dict, - tokenized_prompt=lang_tokens, - tokenized_prompt_mask=lang_masks, - token_ar_mask=token_ar_mask, - token_loss_mask=token_loss_mask, +def assert_sample_actions_match_openpi(*, compile_model: bool = False): + lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model) + original_pi05 = instantiate_original_pi05() + lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs( + lerobot_pi05, + lerobot_preprocessor, ) - # Now use OpenPI's preprocessing - processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) - - return processed_obs - - -def create_original_observation_from_lerobot(lerobot_pi0, batch): - """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" - _batch_size = batch["observation.state"].shape[0] - _device = batch["observation.state"].device - - # Extract the exact same processed inputs that LeRobot uses - images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( - extract_lerobot_processed_inputs(lerobot_pi0, batch) - ) - - # Convert images list to dict with original OpenPI keys - image_dict = { - "base_0_rgb": images[0], - "left_wrist_0_rgb": images[1], - "right_wrist_0_rgb": images[2], - } - - # Convert image masks list to dict with original OpenPI keys - image_masks_dict = { - "base_0_rgb": img_masks[0], - "left_wrist_0_rgb": img_masks[1], - "right_wrist_0_rgb": img_masks[2], - } - - return PI05Observation( - state=batch["observation.state"], - images=image_dict, - image_masks=image_masks_dict, - tokenized_prompt=lang_tokens, - tokenized_prompt_mask=lang_masks, - token_ar_mask=token_ar_mask, - token_loss_mask=token_loss_mask, - ) - - -def test_pi05_original_vs_lerobot(): - """Test PI05 original implementation vs LeRobot implementation.""" - print("Initializing models...") - lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05( - from_pretrained=True - ) # Load pretrained LeRobot model - original_pi0 = instantiate_original_pi05( - from_pretrained=True - ) # Load pretrained OpenPI model from HuggingFace Hub - - print("Creating dummy data...") - batch = create_dummy_data() - batch_lerobot = deepcopy(batch) - - # Test each model with its own preprocessing (more realistic end-to-end test) - print("\nTest each model with its own preprocessing") - print("Creating observation for OpenPI using OpenPI's own preprocessing...") - pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) - - print(f"Task prompt: '{batch['task'][0]}'") - print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") - print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") - print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") - - print("Testing OpenPI with own preprocessing...") - original_pi0.eval() - torch.manual_seed(42) # Set seed for reproducibility - batch_size = batch["observation.state"].shape[0] - noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) - fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) - - with torch.no_grad(): - openpi_actions = original_pi0.sample_actions( - device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 - ) - openpi_actions_unit = openpi_actions[:, 0, :] - print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") - print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}") - print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") - print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") - - print("Testing LeRobot with own preprocessing...") lerobot_pi05.eval() - torch.manual_seed(42) # Set the same seed - - batch_lerobot_processed = lerobot_preprocessor(batch_lerobot) + original_pi05.eval() with torch.no_grad(): - lerobot_actions_own = lerobot_pi05.predict_action_chunk( - batch_lerobot_processed - ) # batch_size, n_action_steps, action_dim - lerobot_actions_unit = lerobot_actions_own[:, 0, :] - print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") - print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}") - print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") - print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") + lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10) + openpi_actions = original_pi05.sample_actions( + device=DEVICE, + observation=openpi_observation, + noise=noise, + num_steps=10, + ) - print("\nComparing end-to-end implementations:") - print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") - print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") - print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") + torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL) - assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4) - assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2) - assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4 + +def test_pi05_forward_matches_openpi(): + assert_forward_matches() + + +def test_pi05_sample_actions_match_openpi(): + assert_sample_actions_match_openpi() + + +def test_pi05_gradient_checkpointing_forward_matches_openpi(): + assert_forward_matches(gradient_checkpointing=True) + + +def test_pi05_compile_forward_matches_openpi(): + assert_forward_matches(compile_model=True) + + +def test_pi05_compile_sample_actions_match_openpi(): + assert_sample_actions_match_openpi(compile_model=True) + + +def test_pi05_compile_gradient_checkpointing_forward_matches_openpi(): + assert_forward_matches(compile_model=True, gradient_checkpointing=True) diff --git a/tests/policies/pi0_pi05/test_pi0_compile.py b/tests/policies/pi0_pi05/test_pi0_compile.py new file mode 100644 index 000000000..4c8f55d4c --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi0_compile.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + +pytest.importorskip("transformers") + +from lerobot.policies.pi0 import PI0Config # noqa: E402 +from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402 +from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402 + assert_cache_stability, + assert_compiled_output_matches_eager, + assert_explain_has_no_graph_breaks, + benchmark_runtime, + make_compile_config, + reset_compile_state, +) +from tests.utils import require_cuda # noqa: E402 + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes", +) + + +def _make_model(*, compile_model): + return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval() + + +def _make_dummy_inputs(config): + device = torch.device("cuda") + common = { + "images": [torch.randn(1, 3, *config.image_resolution, device=device)], + "img_masks": [torch.ones(1, dtype=torch.bool, device=device)], + "lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device), + "lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device), + "state": torch.randn(1, config.max_state_dim, device=device), + } + forward_kwargs = { + **common, + "actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "time": torch.rand(1, device=device), + } + sample_kwargs = { + **common, + "noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device), + "num_steps": config.num_inference_steps, + } + return forward_kwargs, sample_kwargs + + +@require_cuda +def test_pi0_torch_compile_forward_and_sample_actions(): + if not hasattr(torch, "compile"): + pytest.skip("torch.compile is not available") + if not torch._dynamo.is_dynamo_supported(): + pytest.skip("torch._dynamo is not supported on this platform") + + torch.manual_seed(0) + eager_model = _make_model(compile_model=False) + torch.manual_seed(0) + compiled_model = _make_model(compile_model=True) + forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config) + + try: + assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs) + + assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward") + assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions") + + assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward") + assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions") + + benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward") + benchmark_runtime( + eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions" + ) + finally: + reset_compile_state() + del eager_model + del compiled_model + torch.cuda.empty_cache() diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index 62e34b70d..9dad90d60 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -14,51 +14,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test script to verify PI0 policy integration with LeRobot vs the original implementation""" +"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference.""" +import gc import os -from copy import deepcopy -from typing import Any import pytest import torch -# Skip if openpi or transformers is not available -pytest.importorskip("openpi") pytest.importorskip("transformers") -# Skip this entire module in CI -pytestmark = pytest.mark.skipif( - os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", - reason="This test requires local OpenPI installation and is not meant for CI", +from lerobot.configs import PreTrainedConfig # noqa: E402 +from lerobot.policies.pi0 import PI0Policy # noqa: E402 +from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 +from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 +from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402 + assert_processor_inputs_match_lerobot, + clone_batch, + deterministic_openpi_forward_preprocess, + fix_reference_state_dict, + fixed_flow_sampling, + load_openpi_reference_state_dict, + make_openpi_observation_from_raw, + openpi_model_actions_from_raw, ) -from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402 +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes", +) -# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. -from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 - -from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402 -from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyProcessorPipeline # noqa: E402 -from lerobot.types import PolicyAction # noqa: E402 - -# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 DUMMY_STATE_DIM = 32 DUMMY_ACTION_HORIZON = 50 -DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05) -DEVICE = "cpu" # Use CPU to avoid memory issues for testing +DUMMY_MAX_TOKEN_LEN = 48 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +COMPILE_MODE = "default" +FORWARD_RTOL = 1e-4 +FORWARD_ATOL = 1e-4 +SAMPLE_RTOL = 1e-2 +SAMPLE_ATOL = 5e-3 DUMMY_DATASET_STATS = { - "observation.state": { + OBS_STATE: { "mean": torch.zeros(DUMMY_STATE_DIM), "std": torch.ones(DUMMY_STATE_DIM), "q01": torch.zeros(DUMMY_STATE_DIM), "q99": torch.ones(DUMMY_STATE_DIM), }, - "action": { + ACTION: { "mean": torch.zeros(DUMMY_ACTION_DIM), "std": torch.ones(DUMMY_ACTION_DIM), "q01": torch.zeros(DUMMY_ACTION_DIM), @@ -87,6 +92,15 @@ DUMMY_DATASET_STATS = { } +@pytest.fixture(autouse=True) +def cleanup_cuda_after_test(): + yield + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + class PI0BaseOriginalConfig: action_dim: int = DUMMY_ACTION_DIM action_horizon: int = DUMMY_ACTION_HORIZON @@ -95,333 +109,156 @@ class PI0BaseOriginalConfig: precision: str = "float32" pi05: bool = False dtype: str = "float32" + pytorch_compile_mode: str | None = None -def instantiate_lerobot_pi0( - from_pretrained: bool = False, -) -> tuple[ - PI0Policy, - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - if from_pretrained: - # Load the policy first - policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True) - else: - config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32") - policy = PI0Policy(config) +def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False): + config = PreTrainedConfig.from_pretrained("lerobot/pi0_base") + config.device = str(DEVICE) + config.dtype = "float32" + config.compile_model = compile_model + config.compile_mode = COMPILE_MODE + config.gradient_checkpointing = gradient_checkpointing + policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True) policy.to(DEVICE) - policy.config.device = DEVICE - preprocessor, postprocessor = make_pi0_pre_post_processors( - config=policy.config, dataset_stats=DUMMY_DATASET_STATS - ) - return (policy, preprocessor, postprocessor) + policy.config.device = str(DEVICE) + preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS) + return policy, preprocessor -def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): - config = PI0BaseOriginalConfig() - policy = PI0Pytorch(config) - - if from_pretrained: - try: - print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...") - - # Download the model from HuggingFace Hub - import safetensors.torch - from huggingface_hub import snapshot_download - - # Download the entire repository - if model_path and os.path.exists(model_path): - cache_dir = model_path - print(f"Using cached model from: {cache_dir}") - else: - cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model") - print(f"Downloaded model to: {cache_dir}") - - # Try to load safetensors format first - model_file = os.path.join(cache_dir, "model.safetensors") - if os.path.exists(model_file): - state_dict = safetensors.torch.load_file(model_file) - print(f"Loaded {len(state_dict)} parameters from safetensors") - else: - raise FileNotFoundError(f"No safetensors file found in {cache_dir}") - - # Load the state dict into the model - missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) - - if missing_keys: - print(f"Missing keys: {len(missing_keys)}") - if len(missing_keys) <= 5: - for key in missing_keys: - print(f" - {key}") - else: - for key in missing_keys[:5]: - print(f" - {key}") - print(f" ... and {len(missing_keys) - 5} more") - - if unexpected_keys: - print(f"Unexpected keys: {len(unexpected_keys)}") - if len(unexpected_keys) <= 5: - for key in unexpected_keys: - print(f" - {key}") - else: - for key in unexpected_keys[:5]: - print(f" - {key}") - print(f" ... and {len(unexpected_keys) - 5} more") - - if not missing_keys and not unexpected_keys: - print("All pretrained weights loaded successfully!") - else: - print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") - - except Exception as e: - print(f"Failed to load pretrained weights: {e}") - print(" Using randomly initialized weights...") - import traceback - - traceback.print_exc() - - policy.to(DEVICE) +def instantiate_original_pi0(): + policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE) + state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base")) + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + assert missing_keys == [] + assert unexpected_keys == [] return policy def create_dummy_data(): - batch_size = 2 # Reduce batch size for testing - device = DEVICE - - # Use the exact same prompt for both implementations + batch_size = 2 prompt = "Pick up the red block and place it in the bin" - - batch = { - "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), - "action": torch.randn( - batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + return { + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE), + ACTION: torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE ), - # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) "observation.images.base_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), "observation.images.left_wrist_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), "observation.images.right_wrist_0_rgb": torch.rand( - batch_size, 3, 224, 224, dtype=torch.float32, device=device + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE ), - # Add the task prompt for LeRobot - provide as list with single element to trigger expansion "task": [prompt for _ in range(batch_size)], } - return batch -def extract_lerobot_processed_inputs(lerobot_pi0, batch): - """Extract the exact same processed inputs that LeRobot uses internally.""" - # Get the tokenized language from LeRobot's internal method - lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) - - # Get the preprocessed images from LeRobot's internal method - images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) - - # Create dummy token_ar_mask and token_loss_mask for original implementation - token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) - token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) - - return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask +def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor): + torch.manual_seed(0) + raw_batch = create_dummy_data() + lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch)) + openpi_observation = make_openpi_observation_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + max_token_len=DUMMY_MAX_TOKEN_LEN, + dataset_stats=DUMMY_DATASET_STATS, + pi05=False, + ) + openpi_actions = openpi_model_actions_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + dataset_stats=DUMMY_DATASET_STATS, + pi05=False, + ) + assert_processor_inputs_match_lerobot( + lerobot_pi0, + lerobot_batch, + openpi_observation, + compare_state=True, + ) + batch_size = raw_batch[OBS_STATE].shape[0] + noise = torch.randn( + batch_size, + DUMMY_ACTION_HORIZON, + DUMMY_ACTION_DIM, + dtype=torch.float32, + device=DEVICE, + ) + time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE) + return lerobot_batch, openpi_observation, openpi_actions, noise, time -class PI0Observation: - """Observation class that matches the original OpenPI format.""" +def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False): + lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0( + compile_model=compile_model, + gradient_checkpointing=gradient_checkpointing, + ) + original_pi0 = instantiate_original_pi0() + lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs( + lerobot_pi0, + lerobot_preprocessor, + ) - def __init__( - self, - state, - images, - image_masks, - tokenized_prompt, - tokenized_prompt_mask, - token_ar_mask, - token_loss_mask, - ): - self.state = state - self.images = images - self.image_masks = image_masks - self.tokenized_prompt = tokenized_prompt - self.tokenized_prompt_mask = tokenized_prompt_mask - self.token_ar_mask = token_ar_mask - self.token_loss_mask = token_loss_mask - - -def create_original_observation_with_openpi_preprocessing(batch): - """Create observation object for OpenPI using OpenPI's own preprocessing.""" - batch_size = batch["observation.state"].shape[0] - device = batch["observation.state"].device - - # Create tokenizer for OpenPI (same as LeRobot uses) - tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - - # Get task description - if "task" in batch: - tasks = batch["task"] - if isinstance(tasks, str): - # Single string: add newline if not present, then convert to list - if not tasks.endswith("\n"): - tasks = f"{tasks}\n" - tasks = [tasks] - elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks): - # List of strings: add newline to each if not present - tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks] - if len(tasks) == 1: - # Expand to batch size - tasks = tasks * batch_size - if len(tasks) != batch_size: - raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}") - # If task is neither string nor list of strings, leave unchanged + if gradient_checkpointing: + lerobot_pi0.train() else: - # Default task if not provided - tasks = ["Pick up the object\n"] * batch_size - - # Tokenize with max_length padding to match OpenPI's expected format - tokenized = tokenizer( - tasks, - padding="max_length", - padding_side="right", - truncation=True, - max_length=DUMMY_MAX_TOKEN_LEN, - return_tensors="pt", - ) - - lang_tokens = tokenized["input_ids"].to(device) - lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) - - # Create dummy token_ar_mask and token_loss_mask for OpenPI - token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) - token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) - - # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) - image_dict = { - "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, - "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, - "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, - } - - # Create image masks (all ones for real images) - image_masks_dict = {} - for key in image_dict: - image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) - - # Create raw observation object (before preprocessing) - raw_observation = PI0Observation( - state=batch["observation.state"], - images=image_dict, - image_masks=image_masks_dict, - tokenized_prompt=lang_tokens, - tokenized_prompt_mask=lang_masks, - token_ar_mask=token_ar_mask, - token_loss_mask=token_loss_mask, - ) - - # Now use OpenPI's preprocessing - processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) - - return processed_obs - - -def create_original_observation_from_lerobot(lerobot_pi0, batch): - """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" - _batch_size = batch["observation.state"].shape[0] - _device = batch["observation.state"].device - - # Extract the exact same processed inputs that LeRobot uses - images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( - extract_lerobot_processed_inputs(lerobot_pi0, batch) - ) - - # Convert images list to dict with original OpenPI keys - image_dict = { - "base_0_rgb": images[0], - "left_wrist_0_rgb": images[1], - "right_wrist_0_rgb": images[2], - } - - # Convert image masks list to dict with original OpenPI keys - image_masks_dict = { - "base_0_rgb": img_masks[0], - "left_wrist_0_rgb": img_masks[1], - "right_wrist_0_rgb": img_masks[2], - } - - return PI0Observation( - state=batch["observation.state"], - images=image_dict, - image_masks=image_masks_dict, - tokenized_prompt=lang_tokens, - tokenized_prompt_mask=lang_masks, - token_ar_mask=token_ar_mask, - token_loss_mask=token_loss_mask, - ) - - -def test_pi0_original_vs_lerobot(): - """Test PI0 original implementation vs LeRobot implementation.""" - print("Initializing models...") - lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0( - from_pretrained=True - ) # Load pretrained LeRobot model - original_pi0 = instantiate_original_pi0( - from_pretrained=True - ) # Load pretrained OpenPI model from HuggingFace Hub - - print("Creating dummy data...") - batch = create_dummy_data() - batch_lerobot = deepcopy(batch) - - # Test each model with its own preprocessing (more realistic end-to-end test) - print("\nTest each model with its own preprocessing") - print("Creating observation for OpenPI using OpenPI's own preprocessing...") - pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) - - print(f"Task prompt: '{batch['task'][0]}'") - print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") - print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") - print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") - - print("Testing OpenPI with own preprocessing...") + lerobot_pi0.eval() original_pi0.eval() - torch.manual_seed(42) # Set seed for reproducibility - batch_size = batch["observation.state"].shape[0] - noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) - fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) - with torch.no_grad(): - openpi_actions = original_pi0.sample_actions( - device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 - ) - openpi_actions_unit = openpi_actions[:, 0, :] - print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") - print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}") - print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") - print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") + with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time): + lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none") + with deterministic_openpi_forward_preprocess(original_pi0): + openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time) + openpi_loss = openpi_losses.mean(dim=(1, 2)) + + torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL) + + +def assert_sample_actions_match_openpi(*, compile_model: bool = False): + lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model) + original_pi0 = instantiate_original_pi0() + lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs( + lerobot_pi0, + lerobot_preprocessor, + ) - print("Testing LeRobot with own preprocessing...") lerobot_pi0.eval() - torch.manual_seed(42) # Set the same seed - - batch_lerobot_processed = lerobot_preprocessor(batch_lerobot) + original_pi0.eval() with torch.no_grad(): - lerobot_actions_own = lerobot_pi0.predict_action_chunk( - batch_lerobot_processed - ) # batch_size, n_action_steps, action_dim - lerobot_actions_unit = lerobot_actions_own[:, 0, :] - print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") - print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}") - print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") - print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") + lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10) + openpi_actions = original_pi0.sample_actions( + device=DEVICE, + observation=openpi_observation, + noise=noise, + num_steps=10, + ) - print("\nComparing end-to-end implementations:") - print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") - print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") - print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") + torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL) - assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4) - assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2) - assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4 + +def test_pi0_forward_matches_openpi(): + assert_forward_matches() + + +def test_pi0_sample_actions_match_openpi(): + assert_sample_actions_match_openpi() + + +def test_pi0_gradient_checkpointing_forward_matches_openpi(): + assert_forward_matches(gradient_checkpointing=True) + + +def test_pi0_compile_forward_matches_openpi(): + assert_forward_matches(compile_model=True) + + +def test_pi0_compile_sample_actions_match_openpi(): + assert_sample_actions_match_openpi(compile_model=True) + + +def test_pi0_compile_gradient_checkpointing_forward_matches_openpi(): + assert_forward_matches(compile_model=True, gradient_checkpointing=True) diff --git a/tests/policies/pi0_pi05/utils/__init__.py b/tests/policies/pi0_pi05/utils/__init__.py new file mode 100644 index 000000000..9a7a15a09 --- /dev/null +++ b/tests/policies/pi0_pi05/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities shared by PI0/PI05 policy tests.""" diff --git a/tests/policies/pi0_pi05/utils/openpi_parity.py b/tests/policies/pi0_pi05/utils/openpi_parity.py new file mode 100644 index 000000000..f66e4e473 --- /dev/null +++ b/tests/policies/pi0_pi05/utils/openpi_parity.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +import numpy as np +import safetensors.torch +import torch +import torch.nn.functional as F # noqa: N812 +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer + +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) +from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing + +IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") +TOKENIZER_NAME = "google/paligemma-3b-pt-224" + + +@dataclass +class OpenPIObservation: + state: torch.Tensor + images: dict[str, torch.Tensor] + image_masks: dict[str, torch.Tensor] + tokenized_prompt: torch.Tensor + tokenized_prompt_mask: torch.Tensor + token_ar_mask: torch.Tensor + token_loss_mask: torch.Tensor + + +@lru_cache(maxsize=1) +def paligemma_tokenizer(): + return AutoTokenizer.from_pretrained(TOKENIZER_NAME) + + +def clone_batch(batch: dict) -> dict: + return { + key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items() + } + + +def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor: + if tensor.shape[-1] > target_dim: + raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}") + return F.pad(tensor, (0, target_dim - tensor.shape[-1])) + + +def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor: + mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype) + std = stats["std"].to(device=tensor.device, dtype=tensor.dtype) + return (tensor - mean) / (std + 1e-8) + + +def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor: + q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype) + q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype) + denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01) + return 2.0 * (tensor - q01) / denom - 1.0 + + +def openpi_model_state_from_raw( + batch: dict[str, torch.Tensor], + *, + action_dim: int, + dataset_stats: dict[str, dict[str, torch.Tensor]], + pi05: bool, +) -> torch.Tensor: + state = batch[OBS_STATE].to(dtype=torch.float32) + if pi05: + state = quantile_normalize(state, dataset_stats[OBS_STATE]) + else: + state = mean_std_normalize(state, dataset_stats[OBS_STATE]) + return pad_last_dim(state, action_dim) + + +def openpi_model_actions_from_raw( + batch: dict[str, torch.Tensor], + *, + action_dim: int, + dataset_stats: dict[str, dict[str, torch.Tensor]], + pi05: bool, +) -> torch.Tensor: + actions = batch[ACTION].to(dtype=torch.float32) + if pi05: + actions = quantile_normalize(actions, dataset_stats[ACTION]) + else: + actions = mean_std_normalize(actions, dataset_stats[ACTION]) + return pad_last_dim(actions, action_dim) + + +def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]: + tasks = batch.get("task") + if tasks is None: + raise ValueError("The parity batch must include a task prompt.") + if isinstance(tasks, str): + return [tasks] * batch_size + if len(tasks) == 1: + return [tasks[0]] * batch_size + if len(tasks) != batch_size: + raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}") + return list(tasks) + + +def _format_pi0_prompts(tasks: list[str]) -> list[str]: + return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks] + + +def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]: + state_np = normalized_state.detach().cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + prompts = [] + for task, state in zip(tasks, discretized_states, strict=True): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, state)) + prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ") + return prompts + + +def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str): + tokenized = paligemma_tokenizer()( + prompts, + padding="max_length", + padding_side="right", + truncation=True, + max_length=max_token_len, + return_tensors="pt", + ) + tokens = tokenized["input_ids"].to(device) + masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool) + return tokens, masks + + +def make_openpi_observation_from_raw( + batch: dict[str, torch.Tensor], + *, + action_dim: int, + max_token_len: int, + dataset_stats: dict[str, dict[str, torch.Tensor]], + pi05: bool, +) -> OpenPIObservation: + batch_size = batch[OBS_STATE].shape[0] + device = batch[OBS_STATE].device + state = openpi_model_state_from_raw( + batch, + action_dim=action_dim, + dataset_stats=dataset_stats, + pi05=pi05, + ) + + tasks = _tasks_from_raw(batch, batch_size) + prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks) + tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device) + + images = { + key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0 + for key in IMAGE_KEYS + } + image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS} + + return OpenPIObservation( + state=state, + images=images, + image_masks=image_masks, + tokenized_prompt=tokens, + tokenized_prompt_mask=masks, + token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32), + token_loss_mask=torch.ones_like(masks, dtype=torch.bool), + ) + + +def assert_processor_inputs_match_lerobot( + lerobot_policy, + lerobot_batch: dict[str, torch.Tensor], + openpi_observation: OpenPIObservation, + *, + compare_state: bool, +): + openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False) + lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch) + + # Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same + # raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal. + torch.testing.assert_close( + openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0 + ) + torch.testing.assert_close( + openpi_observation.tokenized_prompt_mask, + lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK], + rtol=0, + atol=0, + ) + + for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True): + torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0) + + for openpi_mask, lerobot_mask in zip( + openpi_processed.image_masks.values(), lerobot_image_masks, strict=True + ): + torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0) + + if compare_state: + torch.testing.assert_close( + openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0 + ) + + +def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]: + cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model")) + return safetensors.torch.load_file(cache_dir / "model.safetensors") + + +def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + fixed_state_dict = dict(state_dict) + lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight" + embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict: + fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone() + return fixed_state_dict + + +@contextmanager +def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]: + original_sample_noise = model.sample_noise + original_sample_time = model.sample_time + + def sample_noise(shape, device): + if tuple(shape) != tuple(noise.shape): + raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}") + return noise.to(device=device) + + def sample_time(batch_size, device): + if batch_size != time.shape[0]: + raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}") + return time.to(device=device) + + model.sample_noise = sample_noise + model.sample_time = sample_time + try: + yield + finally: + model.sample_noise = original_sample_noise + model.sample_time = original_sample_time + + +@contextmanager +def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]: + """Disable OpenPI's training-time image augmentation only inside a parity forward block. + + OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic + image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would + otherwise compare two different image tensors rather than two model implementations. The context manager + keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic. + + `yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even + if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests. + """ + + original_preprocess_observation = openpi_policy._preprocess_observation + + def preprocess_observation(observation, *, train=True): + return original_preprocess_observation(observation, train=False) + + openpi_policy._preprocess_observation = preprocess_observation + try: + yield + finally: + openpi_policy._preprocess_observation = original_preprocess_observation diff --git a/tests/policies/pi0_pi05/utils/torch_compile.py b/tests/policies/pi0_pi05/utils/torch_compile.py new file mode 100644 index 000000000..2e71d15bb --- /dev/null +++ b/tests/policies/pi0_pi05/utils/torch_compile.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections.abc import Callable + +import torch +from torch._dynamo.utils import counters, guard_failures +from torch.profiler import ProfilerActivity + +FORWARD_RTOL = 1e-5 +FORWARD_ATOL = 5e-2 +SAMPLE_RTOL = 1e-5 +SAMPLE_ATOL = 1e-2 +COMPILE_MODE = "max-autotune" +STEADY_STATE_WARMUPS = 3 +STEADY_STATE_REPEATS = 3 + + +def make_compile_config(config_cls, *, compile_model): + return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE) + + +def counter_total(name): + return sum(counters.get(name, {}).values()) + + +def compile_snapshot(): + return { + "graph_breaks": counter_total("graph_break"), + "recompiles": counter_total("recompiles"), + "recompile_limits": counter_total("recompile_limit"), + "unique_graphs": counters["stats"].get("unique_graphs", 0), + } + + +def reset_compile_state(): + torch._dynamo.reset() + counters.clear() + guard_failures.clear() + + +def clone_cuda_graph_output(output): + if torch.is_tensor(output): + return output.clone() + if isinstance(output, tuple): + return tuple(clone_cuda_graph_output(item) for item in output) + if isinstance(output, list): + return [clone_cuda_graph_output(item) for item in output] + if isinstance(output, dict): + return {key: clone_cuda_graph_output(value) for key, value in output.items()} + return output + + +def run_model_step(fn: Callable, kwargs: dict): + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + return fn(**kwargs) + + +def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str): + reset_compile_state() + explanation = torch._dynamo.explain(fn)(**kwargs) + + assert explanation.graph_count > 0, f"{label} was not captured by Dynamo" + assert explanation.graph_break_count == 0, ( + f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}" + ) + assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}" + + print( + f"{label} capture: graphs={explanation.graph_count}, " + f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, " + f"guards={len(explanation.out_guards or [])}" + ) + return explanation + + +@torch.no_grad() +def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs): + eager_forward = eager_model.forward(**forward_kwargs) + compiled_forward = compiled_model.forward(**forward_kwargs) + torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL) + + eager_actions = eager_model.sample_actions(**sample_kwargs) + compiled_actions = compiled_model.sample_actions(**sample_kwargs) + torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL) + + +@torch.no_grad() +def assert_cache_stability(fn: Callable, kwargs: dict, label: str): + reset_compile_state() + + first_output = clone_cuda_graph_output(run_model_step(fn, kwargs)) + first_snapshot = compile_snapshot() + second_output = clone_cuda_graph_output(run_model_step(fn, kwargs)) + second_snapshot = compile_snapshot() + third_output = clone_cuda_graph_output(run_model_step(fn, kwargs)) + third_snapshot = compile_snapshot() + + torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL) + torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL) + assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph" + assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}" + assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}" + assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}" + assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], ( + f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}" + ) + assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], ( + f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}" + ) + assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}" + + print(f"{label} cache: first={first_snapshot}, third={third_snapshot}") + + +@torch.no_grad() +def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str): + run_warmups(eager_fn, kwargs) + run_warmups(compiled_fn, kwargs) + torch.cuda.synchronize() + + eager_metrics = profile_callable(eager_fn, kwargs) + compiled_metrics = profile_callable(compiled_fn, kwargs) + speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"] + + print( + f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, " + f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, " + f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/" + f"{compiled_metrics['host_wall_ms']:.3f}, " + f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/" + f"{compiled_metrics['cpu_self_time_ms']:.3f}, " + f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/" + f"{compiled_metrics['cuda_launch_count']}, " + f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/" + f"{compiled_metrics['profiler_event_count']}, " + f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/" + f"{compiled_metrics['peak_mem_mib']:.1f}" + ) + + assert eager_metrics["cuda_event_ms"] > 0 + assert compiled_metrics["cuda_event_ms"] > 0 + assert eager_metrics["profiler_event_count"] > 0 + assert compiled_metrics["profiler_event_count"] > 0 + return eager_metrics, compiled_metrics + + +def run_warmups(fn: Callable, kwargs: dict): + for _ in range(STEADY_STATE_WARMUPS): + run_model_step(fn, kwargs) + torch.cuda.synchronize() + + +def profile_callable(fn: Callable, kwargs: dict): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + host_start = time.perf_counter() + start_event.record() + for _ in range(STEADY_STATE_REPEATS): + run_model_step(fn, kwargs) + end_event.record() + torch.cuda.synchronize() + cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS + host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS + peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2 + + with torch.profiler.profile( + activities=[ProfilerActivity.CPU], + ) as profiler: + run_model_step(fn, kwargs) + torch.cuda.synchronize() + + key_averages = profiler.key_averages() + cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000 + cuda_launch_count = sum( + event.count + for event in key_averages + if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"} + ) + profiler_event_count = sum(event.count for event in key_averages) + + return { + "cuda_event_ms": cuda_event_ms, + "host_wall_ms": host_wall_ms, + "cpu_self_time_ms": cpu_self_time_ms, + "cuda_launch_count": cuda_launch_count, + "profiler_event_count": profiler_event_count, + "peak_mem_mib": peak_mem_mib, + } diff --git a/tests/processor/test_pi05_processor.py b/tests/processor/test_pi05_processor.py new file mode 100644 index 000000000..b3dd85f45 --- /dev/null +++ b/tests/processor/test_pi05_processor.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors.""" + +import os + +import pytest +import torch + +pytest.importorskip("transformers") + +from lerobot.configs import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.pi05 import PI05Policy # noqa: E402 +from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402 +from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 +from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402 + IMAGE_KEYS, + assert_processor_inputs_match_lerobot, + clone_batch, + make_openpi_observation_from_raw, + openpi_model_actions_from_raw, +) + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.", +) + +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 200 +DEVICE = torch.device("cpu") + +DUMMY_DATASET_STATS = { + OBS_STATE: { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + "q01": torch.zeros(DUMMY_STATE_DIM), + "q99": torch.ones(DUMMY_STATE_DIM), + }, + ACTION: { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + "q01": torch.zeros(DUMMY_ACTION_DIM), + "q99": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + key: { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + } + for key in IMAGE_KEYS + }, +} + + +class PI05PolicyInputAdapter(torch.nn.Module): + """Minimal adapter exposing PI0.5 policy image preparation without loading model weights.""" + + _preprocess_images = PI05Policy._preprocess_images + + def __init__(self, config: PI05Config) -> None: + super().__init__() + self.config = config + self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False) + + +def create_pi05_config() -> PI05Config: + config = PI05Config(device=str(DEVICE)) + config.max_state_dim = DUMMY_STATE_DIM + config.max_action_dim = DUMMY_ACTION_DIM + config.chunk_size = DUMMY_ACTION_HORIZON + config.n_action_steps = DUMMY_ACTION_HORIZON + config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)), + **{ + f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) + for key in IMAGE_KEYS + }, + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)), + } + return config + + +def create_dummy_data() -> dict: + batch_size = 2 + prompt = "Pick up the red block and place it in the bin" + return { + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE), + ACTION: torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE + ), + **{ + f"observation.images.{key}": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE + ) + for key in IMAGE_KEYS + }, + "task": [prompt for _ in range(batch_size)], + } + + +def test_pi05_processor_inputs_match_openpi_reference(): + torch.manual_seed(0) + config = create_pi05_config() + preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS) + + raw_batch = create_dummy_data() + lerobot_batch = preprocessor(clone_batch(raw_batch)) + openpi_observation = make_openpi_observation_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + max_token_len=DUMMY_MAX_TOKEN_LEN, + dataset_stats=DUMMY_DATASET_STATS, + pi05=True, + ) + + assert_processor_inputs_match_lerobot( + PI05PolicyInputAdapter(config), + lerobot_batch, + openpi_observation, + compare_state=False, + ) + torch.testing.assert_close( + lerobot_batch[ACTION], + openpi_model_actions_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + dataset_stats=DUMMY_DATASET_STATS, + pi05=True, + ), + rtol=0, + atol=0, + ) diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py new file mode 100644 index 000000000..e9d5b4a37 --- /dev/null +++ b/tests/processor/test_pi0_processor.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors.""" + +import os + +import pytest +import torch + +pytest.importorskip("transformers") + +from lerobot.configs import FeatureType, PolicyFeature # noqa: E402 +from lerobot.policies.pi0 import PI0Policy # noqa: E402 +from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402 +from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 +from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402 +from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402 + IMAGE_KEYS, + assert_processor_inputs_match_lerobot, + clone_batch, + make_openpi_observation_from_raw, + openpi_model_actions_from_raw, +) + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.", +) + +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 48 +DEVICE = torch.device("cpu") + +DUMMY_DATASET_STATS = { + OBS_STATE: { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + "q01": torch.zeros(DUMMY_STATE_DIM), + "q99": torch.ones(DUMMY_STATE_DIM), + }, + ACTION: { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + "q01": torch.zeros(DUMMY_ACTION_DIM), + "q99": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + key: { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + } + for key in IMAGE_KEYS + }, +} + + +class PI0PolicyInputAdapter(torch.nn.Module): + """Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights.""" + + _preprocess_images = PI0Policy._preprocess_images + prepare_state = PI0Policy.prepare_state + + def __init__(self, config: PI0Config) -> None: + super().__init__() + self.config = config + self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False) + + +def create_pi0_config() -> PI0Config: + config = PI0Config(device=str(DEVICE)) + config.max_state_dim = DUMMY_STATE_DIM + config.max_action_dim = DUMMY_ACTION_DIM + config.chunk_size = DUMMY_ACTION_HORIZON + config.n_action_steps = DUMMY_ACTION_HORIZON + config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)), + **{ + f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) + for key in IMAGE_KEYS + }, + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)), + } + return config + + +def create_dummy_data() -> dict: + batch_size = 2 + prompt = "Pick up the red block and place it in the bin" + return { + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE), + ACTION: torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE + ), + **{ + f"observation.images.{key}": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE + ) + for key in IMAGE_KEYS + }, + "task": [prompt for _ in range(batch_size)], + } + + +def test_pi0_processor_inputs_match_openpi_reference(): + torch.manual_seed(0) + config = create_pi0_config() + preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS) + + raw_batch = create_dummy_data() + lerobot_batch = preprocessor(clone_batch(raw_batch)) + openpi_observation = make_openpi_observation_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + max_token_len=DUMMY_MAX_TOKEN_LEN, + dataset_stats=DUMMY_DATASET_STATS, + pi05=False, + ) + + assert_processor_inputs_match_lerobot( + PI0PolicyInputAdapter(config), + lerobot_batch, + openpi_observation, + compare_state=True, + ) + torch.testing.assert_close( + lerobot_batch[ACTION], + openpi_model_actions_from_raw( + raw_batch, + action_dim=DUMMY_ACTION_DIM, + dataset_stats=DUMMY_DATASET_STATS, + pi05=False, + ), + rtol=0, + atol=0, + ) From 9f437d86b6d74982c26b1ef499fa32d71cbe115f Mon Sep 17 00:00:00 2001 From: Haoming Song Date: Fri, 22 May 2026 16:31:04 +0800 Subject: [PATCH 4/7] fix(groot): align GR00TN15Config with transformers config dataclasses (#3606) * fix(gr00t): fix gr00t config dataclass init TypeError * fix(groot): guard strict config decorator without transformers for passing CI --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> From 8194897994be649cc85405fe54eb9ede751886b1 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 22 May 2026 12:03:07 +0200 Subject: [PATCH 5/7] fix(deps): cap placo below 0.9.16 and harden kinematics import (#3647) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(deps): cap placo below 0.9.16 and harden kinematics import placo 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04 (noble ships urdfdom 3.x). Importing placo on that base crashes with: ImportError: liburdfdom_sensor.so.4.0: cannot open shared object file This broke nightly Latest Deps tests (CPU and GPU) when the lockfile upgrade picked placo 0.9.16, since lerobot.model.kinematics unconditionally imports placo when _placo_available is true, and that check (importlib.util.find_spec) cannot detect dlopen failures of transitive shared libraries — so unrelated subsystems (RL actor, gym_manipulator) became unimportable. Two changes: 1. Pin placo to <0.9.16 in pyproject.toml + regenerate uv.lock (0.9.16 → 0.9.15). Short-term unblock for nightly CI until system urdfdom 4.x is broadly available. 2. Harden the import guard in src/lerobot/model/kinematics.py: wrap 'import placo' in try/except ImportError so a missing transitive .so no longer crashes module import. RobotKinematics instantiation now raises an informative ImportError citing the underlying dlopen failure via _raise_if_placo_unusable(). Co-Authored-By: Claude Opus 4.7 (1M context) * fix(kinematics): hoist _placo_runtime_error to module scope for mypy Mypy walks the TYPE_CHECKING branch in which the runtime else-block is not executed, so _placo_runtime_error was only defined at runtime and mypy reported 'Name "_placo_runtime_error" is not defined' on the three references inside _raise_if_placo_unusable. Declare the symbol unconditionally at module scope with a default of None; the runtime import-failure branch still assigns to it. Co-Authored-By: Claude Opus 4.7 (1M context) * style(kinematics): drop verbose comments around placo import guard Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- pyproject.toml | 4 +++- src/lerobot/model/kinematics.py | 20 +++++++++++++++++--- uv.lock | 22 +++++++++++----------- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca6248c95..5d182648c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"] # Common av-dep = ["av>=15.0.0,<16.0.0"] pygame-dep = ["pygame>=2.5.1,<2.7.0"] -placo-dep = ["placo>=0.9.6,<0.9.17"] +# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04 +# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available. +placo-dep = ["placo>=0.9.6,<0.9.16"] transformers-dep = ["transformers>=5.4.0,<5.6.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] diff --git a/src/lerobot/model/kinematics.py b/src/lerobot/model/kinematics.py index 01705ded5..45bd7d438 100644 --- a/src/lerobot/model/kinematics.py +++ b/src/lerobot/model/kinematics.py @@ -18,12 +18,25 @@ from typing import TYPE_CHECKING import numpy as np -from lerobot.utils.import_utils import _placo_available, require_package +from lerobot.utils.import_utils import require_package -if TYPE_CHECKING or _placo_available: +_placo_runtime_error: ImportError | None = None + +if TYPE_CHECKING: import placo # type: ignore[import-not-found] else: - placo = None + try: + import placo # type: ignore[import-not-found] + except ImportError as _placo_import_err: + placo = None + _placo_runtime_error = _placo_import_err + + +def _raise_if_placo_unusable() -> None: + if placo is None and _placo_runtime_error is not None: + raise ImportError( + f"placo is installed but failed to import: {_placo_runtime_error!s}" + ) from _placo_runtime_error class RobotKinematics: @@ -44,6 +57,7 @@ class RobotKinematics: joint_names (list[str] | None): List of joint names to use for the kinematics solver """ require_package("placo", extra="placo-dep") + _raise_if_placo_unusable() self.robot = placo.RobotWrapper(urdf_path) self.solver = placo.KinematicsSolver(self.robot) diff --git a/uv.lock b/uv.lock index 7092f780a..c5f026517 100644 --- a/uv.lock +++ b/uv.lock @@ -3203,7 +3203,7 @@ requires-dist = [ { name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" }, { name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" }, { name = "pillow", specifier = ">=10.0.0,<13.0.0" }, - { name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" }, + { name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" }, { name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" }, { name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" }, @@ -4592,7 +4592,7 @@ wheels = [ [[package]] name = "placo" -version = "0.9.16" +version = "0.9.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cmeel" }, @@ -4602,16 +4602,16 @@ dependencies = [ { name = "pin" }, { name = "rhoban-cmeel-jsoncpp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" }, - { url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" }, - { url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" }, - { url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" }, - { url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" }, - { url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" }, - { url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" }, - { url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" }, + { url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" }, + { url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" }, + { url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" }, + { url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" }, + { url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" }, + { url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" }, ] [[package]] From f65f3f7a4a8bc2eb405d692ed297b9f9a3828e20 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Tue, 26 May 2026 13:01:19 +0100 Subject: [PATCH 6/7] Fix policy.path in YAML configs (PR #3145 followup) (#3597) PR #3145 added YAML support for policy.path but left two bugs: 1. extract_path_fields_from_config only deleted config_data[field] when no sibling overrides existed. With siblings, the dict stayed in place and draccus crashed decoding it as PreTrainedConfig (no 'type' key). Sibling overrides go into _config_yaml_overrides and are applied later by from_pretrained(), so the field can always be removed. 2. wrap() updated config_path_cli to the cleaned temp file path but never propagated it to the draccus.parse fallback branch. cli_args still contained --config_path=, so draccus read the original YAML with path: still present. Tests passed because they (a) called extract_path_fields_from_config directly and (b) included type: alongside path: in the YAML, sidestepping both bugs. Co-authored-by: Steven Palma --- src/lerobot/configs/parser.py | 11 +++- tests/test_yaml_policy_path.py | 116 +++++++++++++++++++++++++++++++-- 2 files changed, 117 insertions(+), 10 deletions(-) diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py index d55fa44aa..46cff2b48 100644 --- a/src/lerobot/configs/parser.py +++ b/src/lerobot/configs/parser.py @@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> remaining = config_data[field] if remaining: _config_yaml_overrides[field] = _flatten_to_cli_args(remaining) - else: - del config_data[field] + del config_data[field] modified = True if not modified: @@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]: cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) else: - cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args) + if config_path_cli: + cli_args = filter_arg("config_path", cli_args) + cfg = draccus.parse( + config_class=argtype, + config_path=config_path_cli or config_path, + args=cli_args, + ) response = fn(cfg, *args, **kwargs) return response diff --git a/tests/test_yaml_policy_path.py b/tests/test_yaml_policy_path.py index 710a71c9a..8d8f7f2ec 100644 --- a/tests/test_yaml_policy_path.py +++ b/tests/test_yaml_policy_path.py @@ -1,10 +1,14 @@ """Tests for policy.path support in YAML config files (issue #2957).""" import json +import sys import tempfile +from dataclasses import dataclass, field +from unittest.mock import patch import yaml +from lerobot.configs import parser from lerobot.configs.parser import ( _config_path_args, _config_yaml_overrides, @@ -16,7 +20,8 @@ from lerobot.configs.parser import ( def test_extract_path_fields_from_yaml(): - """Test that policy.path is extracted from a YAML config and removed.""" + """Test that policy.path is extracted from a YAML config and the policy block + is removed entirely (siblings are captured separately as cli_overrides).""" config = { "dataset": {"repo_id": "lerobot/pusht"}, "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False}, @@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml(): config_path = f.name _config_path_args.clear() + _config_yaml_overrides.clear() cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) # Path should be extracted and stored assert _config_path_args["policy"] == "lerobot/smolvla_base" - # Cleaned config should not have the path field + # Cleaned config should not have the policy block at all -- draccus must not + # try to decode it as PreTrainedConfig; the actual config comes from + # from_pretrained(path) with the captured overrides applied on top. with open(cleaned_path) as f: cleaned = yaml.safe_load(f) - assert "path" not in cleaned["policy"] - assert cleaned["policy"]["type"] == "smolvla" - assert cleaned["policy"]["push_to_hub"] is False + assert "policy" not in cleaned # Original dataset should be untouched assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + # Sibling overrides (excluding type/path) captured for from_pretrained. + overrides = get_yaml_overrides("policy") + assert any("push_to_hub=false" in o for o in overrides) + _config_path_args.clear() + _config_yaml_overrides.clear() def test_extract_path_fields_from_json(): - """Test that policy.path is extracted from a JSON config.""" + """Test that policy.path is extracted from a JSON config and the policy + block is removed entirely.""" config = { "policy": {"type": "act", "path": "some/local/path"}, } @@ -54,15 +66,17 @@ def test_extract_path_fields_from_json(): config_path = f.name _config_path_args.clear() + _config_yaml_overrides.clear() cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) assert _config_path_args["policy"] == "some/local/path" with open(cleaned_path) as f: cleaned = json.load(f) - assert "path" not in cleaned["policy"] + assert "policy" not in cleaned _config_path_args.clear() + _config_yaml_overrides.clear() def test_extract_no_path_returns_original(): @@ -216,3 +230,91 @@ def test_flatten_nested_with_bools(): args = _flatten_to_cli_args(d) assert "--optimizer.use_warmup=true" in args assert "--optimizer.lr=0.01" in args + + +def test_extract_removes_field_with_siblings_and_no_type(): + """Regression: when policy.path has siblings but no type:, the entire policy + block must still be removed from the cleaned config. Otherwise draccus tries + to decode the leftover dict as PreTrainedConfig and crashes on the missing + type discriminator. + """ + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": { + "path": "lerobot/smolvla_base", + "n_action_steps": 10, + "dtype": "bfloat16", + }, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) or {} + assert "policy" not in cleaned, "policy block should be fully removed when path is present" + assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + assert _config_path_args["policy"] == "lerobot/smolvla_base" + overrides = get_yaml_overrides("policy") + assert any("n_action_steps=10" in o for o in overrides) + assert any("dtype=bfloat16" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +@dataclass +class _DummyNested: + foo: int = 0 + + +@dataclass +class _DummyConfig: + nested: _DummyNested = field(default_factory=_DummyNested) + other: str = "default" + + @classmethod + def __get_path_fields__(cls): + return ["nested"] + + +def test_wrap_uses_cleaned_config_for_draccus_parse(): + """Regression: wrap() updates config_path_cli to point at the cleaned temp + file but must propagate that to the draccus.parse fallback branch. Without + the fix, cli_args still contains --config_path= and draccus reads + the original YAML with `path:` still in it, crashing on the unknown field. + """ + config = { + "nested": {"path": "some/checkpoint", "foo": 42}, + "other": "set-via-yaml", + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + + captured: dict = {} + + @parser.wrap() + def main(cfg: _DummyConfig) -> _DummyConfig: + captured["cfg"] = cfg + return cfg + + with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]): + main() + + assert captured["cfg"].other == "set-via-yaml" + assert _config_path_args["nested"] == "some/checkpoint" + # Cleaned config dropped `nested:` entirely; defaults stand for this wrapper + # class (a real PreTrainedConfig would now load the checkpoint and apply + # the captured yaml_overrides via from_pretrained()). + assert captured["cfg"].nested.foo == 0 + + _config_path_args.clear() + _config_yaml_overrides.clear() From 5c98e80430d4a747926b45893568e388105a2400 Mon Sep 17 00:00:00 2001 From: Haoming Song Date: Tue, 26 May 2026 20:04:22 +0800 Subject: [PATCH 7/7] fix(gr00t): fix Eagle25VL model and processor crash in transformers>=5.4.0, <5.6.0 (#3652) Co-authored-by: Steven Palma --- .../policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py | 1 + .../groot/eagle2_hg_model/processing_eagle2_5_vl.py | 1 - src/lerobot/policies/groot/processor_groot.py | 6 +++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py index 5a66cfbce..6e5532ea4 100755 --- a/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py +++ b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py @@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel): "SiglipEncoderLayer", ] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True _supports_flash_attn_2 = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py index 7b1f67fef..b36e70c47 100755 --- a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py +++ b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py @@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin): "videos_kwargs", "text_kwargs", ] - image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__( diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 3367de711..6848c7c84 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS "Vendor files are copied during model creation. Create the policy/model first, " "or call ensure_eagle_cache_ready() before building processors." ) - proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True) + proc = AutoProcessor.from_pretrained( + str(cache_dir), + trust_remote_code=True, + fix_mistral_regex=False, + ) proc.tokenizer.padding_side = "left" return proc