diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index da66ac400..2e04aad82 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -149,6 +149,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) + elif policy_type == "pi0_openpi": + return PI0OpenPIConfig(**kwargs) + elif policy_type == "pi05_openpi": + return PI05OpenPIConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") @@ -268,6 +272,22 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, PI0OpenPIConfig): + from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors + + processors = make_pi0_openpi_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, PI05OpenPIConfig): + from lerobot.policies.pi05_openpi.processor_pi05openpi import make_pi05_openpi_pre_post_processors + + processors = make_pi05_openpi_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, SACConfig): from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors diff --git a/src/lerobot/policies/pi0/__init__.py b/src/lerobot/policies/pi0/__init__.py index 12d766633..449461f93 100644 --- a/src/lerobot/policies/pi0/__init__.py +++ b/src/lerobot/policies/pi0/__init__.py @@ -16,5 +16,6 @@ from .configuration_pi0openpi import PI0OpenPIConfig from .modeling_pi0openpi import PI0OpenPIPolicy +from .processor_pi0_openpi import make_pi0_openpi_pre_post_processors -__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"] +__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy", "make_pi0_openpi_pre_post_processors"] diff --git a/src/lerobot/policies/pi0/modeling_pi0openpi.py b/src/lerobot/policies/pi0/modeling_pi0openpi.py index 7be238889..c6ea2895c 100644 --- a/src/lerobot/policies/pi0/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0/modeling_pi0openpi.py @@ -24,16 +24,14 @@ from typing import Literal import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration from lerobot.configs.policies import PreTrainedConfig -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE +from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -50,7 +48,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: @@ -851,31 +849,15 @@ class PI0OpenPIPolicy(PreTrainedPolicy): def __init__( # see lerobot pi0 `__init__` self, config: PI0OpenPIConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance. - dataset_stats: Dataset statistics to be used for normalization. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - - # Create tokenizer for language input - self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - - # Set max token length for tokenizer (from OpenPI) - self.max_token_len = config.tokenizer_max_length - # Initialize the core PI0 model self.model = PI0Pytorch(config) @@ -965,10 +947,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): remap_count = 0 for key, value in fixed_state_dict.items(): - if not key.startswith("model.") and not any( - key.startswith(prefix) - for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."] - ): + if not key.startswith("model."): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 @@ -1143,44 +1122,6 @@ class PI0OpenPIPolicy(PreTrainedPolicy): return images, img_masks - def _tokenize_language( - self, batch: dict[str, Tensor] - ) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language` - """Tokenize language input using PaliGemma tokenizer.""" - device = next(self.parameters()).device - - # Get task description - if "task" in batch: - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - elif isinstance(tasks, list) and len(tasks) == 1: - # Expand to batch size - batch_size = batch[next(iter(batch.keys()))].shape[0] - tasks = tasks * batch_size - else: - # Default task if not provided - batch_size = batch[next(iter(batch.keys()))].shape[0] - tasks = ["Pick up the object"] * batch_size - - # PaliGemma prompt has to end with a new line - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - # Tokenize with max_length padding to match OpenPI's expected format - tokenized = self.tokenizer( - tasks, - padding="max_length", # Use max_length padding as per OpenPI - padding_side="right", # from lerobot pi0 `prepare_language` - truncation=True, - max_length=self.max_token_len, # Use the max token length from config - return_tensors="pt", - ) - - lang_tokens = tokenized["input_ids"].to(device) - lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) - - return lang_tokens, lang_masks - def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) """Pad state""" state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) @@ -1209,11 +1150,9 @@ class PI0OpenPIPolicy(PreTrainedPolicy): """Predict a chunk of actions given environment observations.""" self.eval() - batch = self.normalize_inputs(batch) - # Prepare inputs images, img_masks = self._preprocess_images(batch) - lang_tokens, lang_masks = self._tokenize_language(batch) + lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] state = self.prepare_state(batch) # Sample actions using the model @@ -1223,17 +1162,14 @@ class PI0OpenPIPolicy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward` """Run the batch through the model and compute the loss for training.""" - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) # Prepare inputs images, img_masks = self._preprocess_images(batch) - lang_tokens, lang_masks = self._tokenize_language(batch) + lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] state = self.prepare_state(batch) actions = self.prepare_action(batch) diff --git a/src/lerobot/policies/pi0/processor_pi05openpi.py b/src/lerobot/policies/pi0/processor_pi05openpi.py new file mode 100644 index 000000000..9f85db23c --- /dev/null +++ b/src/lerobot/policies/pi0/processor_pi05openpi.py @@ -0,0 +1,147 @@ +from copy import deepcopy +from typing import Any + +import numpy as np +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig +from lerobot.policies.pi05_openpi.modeling_pi05openpi import pad_vector +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey + + +@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") +class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): + """ + Processor step to prepare the state and tokenize the language input. + """ + + max_state_dim: int + task_key: str = "task" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + + state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) + if state is None: + raise ValueError("State is required for PI05") + tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) + if tasks is None: + raise ValueError("No task found in complementary data") + + # TODO: check if this necessary + state = deepcopy(state) + + # Prepare state (pad to max_state_dim) + state = pad_vector(state, self.max_state_dim) + + # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts + # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + """ + return features + + +def make_pi05_openpi_pre_post_processors( + config: PI05OpenPIConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/pi05/__init__.py b/src/lerobot/policies/pi05/__init__.py index 2b438db85..46f737d0c 100644 --- a/src/lerobot/policies/pi05/__init__.py +++ b/src/lerobot/policies/pi05/__init__.py @@ -16,5 +16,6 @@ from .configuration_pi05openpi import PI05OpenPIConfig from .modeling_pi05openpi import PI05OpenPIPolicy +from .processor_pi05openpi import make_pi05_openpi_pre_post_processors -__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"] +__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy", "make_pi05_openpi_pre_post_processors"] diff --git a/src/lerobot/policies/pi05/modeling_pi05openpi.py b/src/lerobot/policies/pi05/modeling_pi05openpi.py index eb6f95934..e4043e6ca 100644 --- a/src/lerobot/policies/pi05/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05/modeling_pi05openpi.py @@ -19,22 +19,19 @@ import logging import math from collections import deque from pathlib import Path -from typing import Any, Literal +from typing import Literal -import numpy as np import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration from lerobot.configs.policies import PreTrainedConfig -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -53,7 +50,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: @@ -825,31 +822,15 @@ class PI05OpenPIPolicy(PreTrainedPolicy): def __init__( # see lerobot pi0 `__init__` self, config: PI05OpenPIConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance. - dataset_stats: Dataset statistics to be used for normalization. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - - # Create tokenizer for language input - self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - - # Set max token length for tokenizer (from OpenPI) - self.max_token_len = config.tokenizer_max_length - # Initialize the core PI05 model self.model = PI05Pytorch(config) @@ -939,10 +920,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): remap_count = 0 for key, value in fixed_state_dict.items(): - if not key.startswith("model.") and not any( - key.startswith(prefix) - for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."] - ): + if not key.startswith("model."): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 @@ -1121,63 +1099,6 @@ class PI05OpenPIPolicy(PreTrainedPolicy): return images, img_masks - def _tokenize_language_and_state( - self, batch: dict[str, Tensor] - ) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language` - """Tokenize language input using PaliGemma tokenizer.""" - device = next(self.parameters()).device - - # Get task description - if "task" in batch: - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - elif isinstance(tasks, list) and len(tasks) == 1: - # Expand to batch size - batch_size = batch[next(iter(batch.keys()))].shape[0] - tasks = tasks * batch_size - else: - # Default task if not provided - batch_size = batch[next(iter(batch.keys()))].shape[0] - tasks = ["Pick up the object"] * batch_size - - # Handle discrete state input for PI05 (always the case for pi05) - # Get state from batch and discretize it - state: Any | None = batch.get(OBS_STATE) - if state is None: - raise ValueError("Robot state is required for PI05") - - # Prepare state (pad to max_state_dim) - state = pad_vector(state, self.config.max_state_dim) - - # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) - # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) - state_np = state.cpu().numpy() - discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 - - # Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`) - full_prompts = [] - for i, task in enumerate(tasks): - cleaned_text = task.strip().replace("_", " ").replace("\n", " ") - state_str = " ".join(map(str, discretized_states[i])) - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " - full_prompts.append(full_prompt) - - # Tokenize the full prompts with state - tokenized = self.tokenizer( - full_prompts, - padding="max_length", - padding_side="right", - truncation=True, - max_length=self.max_token_len, - return_tensors="pt", - ) - - tokens = tokenized["input_ids"].to(device) - masks = tokenized["attention_mask"].to(device, dtype=torch.bool) - - return tokens, masks - def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) """Pad action""" actions = pad_vector(batch[ACTION], self.config.max_action_dim) @@ -1201,11 +1122,9 @@ class PI05OpenPIPolicy(PreTrainedPolicy): """Predict a chunk of actions given environment observations.""" self.eval() - batch = self.normalize_inputs(batch) - # Prepare inputs images, img_masks = self._preprocess_images(batch) - tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05 + tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] # Sample actions using the model (no separate state needed for PI05) actions = self.model.sample_actions(images, img_masks, tokens, masks) @@ -1214,17 +1133,14 @@ class PI05OpenPIPolicy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward` """Run the batch through the model and compute the loss for training.""" - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) # Prepare inputs images, img_masks = self._preprocess_images(batch) - tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05 + tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) diff --git a/src/lerobot/policies/pi05/processor_pi0_openpi.py b/src/lerobot/policies/pi05/processor_pi0_openpi.py new file mode 100644 index 000000000..14f148d92 --- /dev/null +++ b/src/lerobot/policies/pi05/processor_pi0_openpi.py @@ -0,0 +1,164 @@ +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +@ProcessorStepRegistry.register(name="pi0_openpi_new_line_processor") +class Pi0OpenPINewLineProcessor(ComplementaryDataProcessorStep): + """ + Ensures that the task description string ends with a newline character. + + This processing step is required for compatibility with the PaliGemma tokenizer, + which expects a newline at the end of the text prompt. It handles both single + strings and lists of strings for the 'task' key in complementary data. + """ + + def complementary_data(self, complementary_data): + """ + Adds a newline to the 'task' field if it doesn't already have one. + + Args: + complementary_data: A dictionary that may contain a 'task' key with a + string or list of strings. + + Returns: + A new dictionary with the modified 'task' field. + """ + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + + Args: + features: The input feature dictionary. + + Returns: + The unchanged feature dictionary. + """ + return features + + +def make_pi0_openpi_pre_post_processors( + config: PI0OpenPIConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + Pi0OpenPINewLineProcessor(), # Add newlines before tokenization for PaliGemma + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 131f799d6..ed80cfcfa 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -302,6 +302,65 @@ def clean_state_dict( return new_state_dict +def load_state_dict_with_missing_key_handling( + policy: torch.nn.Module, + state_dict: dict[str, torch.Tensor], + policy_type: str, + known_missing_keys_whitelist: dict[str, list[str]], +) -> list[str]: + """ + Load state dict into policy with graceful handling of missing keys. + + This function loads the state dict with strict=False, filters out whitelisted + missing keys, and provides detailed reporting about any issues found. + + Args: + policy: The policy model to load the state dict into. + state_dict: The cleaned state dictionary to load. + policy_type: The type of policy (used for whitelist lookup). + known_missing_keys_whitelist: Dictionary mapping policy types to lists of + known acceptable missing keys. + + Returns: + List of problematic missing keys that weren't in the whitelist. + """ + # Load the cleaned state dict with strict=False to capture missing/unexpected keys + load_result = policy.load_state_dict(state_dict, strict=False) + + # Check for missing keys + missing_keys = load_result.missing_keys + unexpected_keys = load_result.unexpected_keys + + # Filter out whitelisted missing keys + policy_type_lower = policy_type.lower() + whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, []) + problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys] + + if missing_keys: + if problematic_missing_keys: + print(f"⚠️ WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:") + for key in problematic_missing_keys: + print(f" - {key}") + + if len(missing_keys) > len(problematic_missing_keys): + whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys] + print(f"ℹ️ INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):") + for key in whitelisted_missing: + print(f" - {key}") + + if unexpected_keys: + print(f"⚠️ WARNING: Found {len(unexpected_keys)} unexpected keys:") + for key in unexpected_keys: + print(f" - {key}") + + if not missing_keys and not unexpected_keys: + print("✅ Successfully loaded cleaned state dict into policy model (all keys matched)") + else: + print("⚠️ State dict loaded with some missing/unexpected keys (see details above)") + + return problematic_missing_keys + + def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: """ Converts a feature dictionary from the old config format to the new `PolicyFeature` format. @@ -335,9 +394,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ return converted_features +def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None: + """ + Display final migration summary with warnings about problematic missing keys. + + Args: + problematic_missing_keys: List of missing keys that weren't in the whitelist. + """ + if not problematic_missing_keys: + return + + print("\n" + "=" * 60) + print("🚨 IMPORTANT: MIGRATION COMPLETED WITH WARNINGS") + print("=" * 60) + print( + f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:" + ) + print() + for key in problematic_missing_keys: + print(f" ❌ {key}") + print() + print("These missing keys may indicate:") + print(" • The model architecture has changed") + print(" • Some components were not properly saved in the original model") + print(" • The migration script needs to be updated for this policy type") + print() + print("What to do next:") + print(" 1. Test your migrated model carefully to ensure it works as expected") + print(" 2. If you encounter issues, please open an issue at:") + print(" https://github.com/huggingface/lerobot/issues") + print(" 3. Include this migration log and the missing keys listed above") + print() + print("If the model works correctly despite these warnings, the missing keys") + print("might be expected for your policy type and can be added to the whitelist.") + print("=" * 60) + + def load_model_from_hub( repo_id: str, revision: str | None = None -) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]: """ Downloads and loads a model's state_dict and configs from the Hugging Face Hub. @@ -347,13 +442,12 @@ def load_model_from_hub( Returns: A tuple containing the model's state dictionary, the policy configuration, - and the training configuration. + and the training configuration (None if train_config.json is not found). """ # Download files. safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) - train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) # Load state_dict state_dict = load_safetensors(safetensors_path) @@ -362,8 +456,14 @@ def load_model_from_hub( with open(config_path) as f: config = json.load(f) - with open(train_config_path) as f: - train_config = json.load(f) + # Try to load train_config (optional) + train_config = None + try: + train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) + with open(train_config_path) as f: + train_config = json.load(f) + except FileNotFoundError: + print("train_config.json not found - continuing without training configuration") return state_dict, config, train_config @@ -409,8 +509,15 @@ def main(): state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors")) with open(os.path.join(args.pretrained_path, "config.json")) as f: config = json.load(f) - with open(os.path.join(args.pretrained_path, "train_config.json")) as f: - train_config = json.load(f) + + # Try to load train_config (optional) + train_config = None + train_config_path = os.path.join(args.pretrained_path, "train_config.json") + if os.path.exists(train_config_path): + with open(train_config_path) as f: + train_config = json.load(f) + else: + print("train_config.json not found - continuing without training configuration") else: # Hub repository state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision) @@ -487,10 +594,20 @@ def main(): policy_class = get_policy_class(policy_type) policy = policy_class(policy_config) - # Load the cleaned state dict - policy.load_state_dict(new_state_dict, strict=True) - print("Successfully loaded cleaned state dict into policy model") + # Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types + known_missing_keys_whitelist = { + "pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"], + # Add other policy types and their known missing keys here as needed + } + # Load state dict with graceful missing key handling + problematic_missing_keys = load_state_dict_with_missing_key_handling( + policy=policy, + state_dict=new_state_dict, + policy_type=policy_type, + known_missing_keys_whitelist=known_missing_keys_whitelist, + ) + policy.to(torch.float32) # Create preprocessor and postprocessor using the factory print("Creating preprocessor and postprocessor using make_pre_post_processors...") preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats) @@ -520,7 +637,9 @@ def main(): # Generate and save model card print("Generating model card...") # Get metadata from original config - dataset_repo_id = train_config.get("repo_id", "unknown") + dataset_repo_id = "unknown" + if train_config is not None: + dataset_repo_id = train_config.get("repo_id", "unknown") license = config.get("license", "apache-2.0") tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] @@ -641,6 +760,9 @@ final_action = postprocessor(action) else: print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}") + # Display final summary about any problematic missing keys + display_migration_summary_with_warnings(problematic_missing_keys) + if __name__ == "__main__": main()