feat(processor): convert openpi model with processor

This commit is contained in:
AdilZouitine
2025-09-19 15:48:35 +02:00
parent d691d1e4fe
commit 10f5ea854f
8 changed files with 481 additions and 174 deletions

View File

@@ -149,6 +149,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "pi0_openpi":
return PI0OpenPIConfig(**kwargs)
elif policy_type == "pi05_openpi":
return PI05OpenPIConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -268,6 +272,22 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0OpenPIConfig):
from lerobot.policies.pi0_openpi.processor_pi0_openpi import make_pi0_openpi_pre_post_processors
processors = make_pi0_openpi_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05OpenPIConfig):
from lerobot.policies.pi05_openpi.processor_pi05openpi import make_pi05_openpi_pre_post_processors
processors = make_pi05_openpi_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors

View File

@@ -16,5 +16,6 @@
from .configuration_pi0openpi import PI0OpenPIConfig
from .modeling_pi0openpi import PI0OpenPIPolicy
from .processor_pi0_openpi import make_pi0_openpi_pre_post_processors
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"]
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy", "make_pi0_openpi_pre_post_processors"]

View File

@@ -24,16 +24,14 @@ from typing import Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -50,7 +48,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
@@ -851,31 +849,15 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
def __init__( # see lerobot pi0 `__init__`
self,
config: PI0OpenPIConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance.
dataset_stats: Dataset statistics to be used for normalization.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Create tokenizer for language input
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Set max token length for tokenizer (from OpenPI)
self.max_token_len = config.tokenizer_max_length
# Initialize the core PI0 model
self.model = PI0Pytorch(config)
@@ -965,10 +947,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
remap_count = 0
for key, value in fixed_state_dict.items():
if not key.startswith("model.") and not any(
key.startswith(prefix)
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
):
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
@@ -1143,44 +1122,6 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
return images, img_masks
def _tokenize_language(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
"""Tokenize language input using PaliGemma tokenizer."""
device = next(self.parameters()).device
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
elif isinstance(tasks, list) and len(tasks) == 1:
# Expand to batch size
batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = tasks * batch_size
else:
# Default task if not provided
batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = ["Pick up the object"] * batch_size
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = self.tokenizer(
tasks,
padding="max_length", # Use max_length padding as per OpenPI
padding_side="right", # from lerobot pi0 `prepare_language`
truncation=True,
max_length=self.max_token_len, # Use the max token length from config
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
return lang_tokens, lang_masks
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
@@ -1209,11 +1150,9 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch)
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model
@@ -1223,17 +1162,14 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward`
"""Run the batch through the model and compute the loss for training."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = self._tokenize_language(batch)
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
actions = self.prepare_action(batch)

View File

@@ -0,0 +1,147 @@
from copy import deepcopy
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pi05_openpi.modeling_pi05openpi import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
def make_pi05_openpi_pre_post_processors(
config: PI05OpenPIConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -16,5 +16,6 @@
from .configuration_pi05openpi import PI05OpenPIConfig
from .modeling_pi05openpi import PI05OpenPIPolicy
from .processor_pi05openpi import make_pi05_openpi_pre_post_processors
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy"]
__all__ = ["PI05OpenPIConfig", "PI05OpenPIPolicy", "make_pi05_openpi_pre_post_processors"]

View File

@@ -19,22 +19,19 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import Any, Literal
from typing import Literal
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -53,7 +50,7 @@ def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (e
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
@@ -825,31 +822,15 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
def __init__( # see lerobot pi0 `__init__`
self,
config: PI05OpenPIConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance.
dataset_stats: Dataset statistics to be used for normalization.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Create tokenizer for language input
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Set max token length for tokenizer (from OpenPI)
self.max_token_len = config.tokenizer_max_length
# Initialize the core PI05 model
self.model = PI05Pytorch(config)
@@ -939,10 +920,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
remap_count = 0
for key, value in fixed_state_dict.items():
if not key.startswith("model.") and not any(
key.startswith(prefix)
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
):
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
@@ -1121,63 +1099,6 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
return images, img_masks
def _tokenize_language_and_state(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, Tensor]: # see lerobot pi0 `prepare_language`
"""Tokenize language input using PaliGemma tokenizer."""
device = next(self.parameters()).device
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
elif isinstance(tasks, list) and len(tasks) == 1:
# Expand to batch size
batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = tasks * batch_size
else:
# Default task if not provided
batch_size = batch[next(iter(batch.keys()))].shape[0]
tasks = ["Pick up the object"] * batch_size
# Handle discrete state input for PI05 (always the case for pi05)
# Get state from batch and discretize it
state: Any | None = batch.get(OBS_STATE)
if state is None:
raise ValueError("Robot state is required for PI05")
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.config.max_state_dim)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create full prompts with state included (see openpi `PaligemmaTokenizer.tokenize()`)
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize the full prompts with state
tokenized = self.tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=self.max_token_len,
return_tensors="pt",
)
tokens = tokenized["input_ids"].to(device)
masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
return tokens, masks
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1201,11 +1122,9 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
@@ -1214,17 +1133,14 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # see lerobot pi0 `forward`
"""Run the batch through the model and compute the loss for training."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
tokens, masks = self._tokenize_language_and_state(batch) # State is included in tokens for PI05
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)

View File

@@ -0,0 +1,164 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
@ProcessorStepRegistry.register(name="pi0_openpi_new_line_processor")
class Pi0OpenPINewLineProcessor(ComplementaryDataProcessorStep):
"""
Ensures that the task description string ends with a newline character.
This processing step is required for compatibility with the PaliGemma tokenizer,
which expects a newline at the end of the text prompt. It handles both single
strings and lists of strings for the 'task' key in complementary data.
"""
def complementary_data(self, complementary_data):
"""
Adds a newline to the 'task' field if it doesn't already have one.
Args:
complementary_data: A dictionary that may contain a 'task' key with a
string or list of strings.
Returns:
A new dictionary with the modified 'task' field.
"""
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
Args:
features: The input feature dictionary.
Returns:
The unchanged feature dictionary.
"""
return features
def make_pi0_openpi_pre_post_processors(
config: PI0OpenPIConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
Pi0OpenPINewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -302,6 +302,65 @@ def clean_state_dict(
return new_state_dict
def load_state_dict_with_missing_key_handling(
policy: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
policy_type: str,
known_missing_keys_whitelist: dict[str, list[str]],
) -> list[str]:
"""
Load state dict into policy with graceful handling of missing keys.
This function loads the state dict with strict=False, filters out whitelisted
missing keys, and provides detailed reporting about any issues found.
Args:
policy: The policy model to load the state dict into.
state_dict: The cleaned state dictionary to load.
policy_type: The type of policy (used for whitelist lookup).
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
known acceptable missing keys.
Returns:
List of problematic missing keys that weren't in the whitelist.
"""
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
load_result = policy.load_state_dict(state_dict, strict=False)
# Check for missing keys
missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys
# Filter out whitelisted missing keys
policy_type_lower = policy_type.lower()
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
if missing_keys:
if problematic_missing_keys:
print(f"⚠️ WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
for key in problematic_missing_keys:
print(f" - {key}")
if len(missing_keys) > len(problematic_missing_keys):
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
print(f" INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
for key in whitelisted_missing:
print(f" - {key}")
if unexpected_keys:
print(f"⚠️ WARNING: Found {len(unexpected_keys)} unexpected keys:")
for key in unexpected_keys:
print(f" - {key}")
if not missing_keys and not unexpected_keys:
print("✅ Successfully loaded cleaned state dict into policy model (all keys matched)")
else:
print("⚠️ State dict loaded with some missing/unexpected keys (see details above)")
return problematic_missing_keys
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
@@ -335,9 +394,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
return converted_features
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
"""
Display final migration summary with warnings about problematic missing keys.
Args:
problematic_missing_keys: List of missing keys that weren't in the whitelist.
"""
if not problematic_missing_keys:
return
print("\n" + "=" * 60)
print("🚨 IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
print("=" * 60)
print(
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
)
print()
for key in problematic_missing_keys:
print(f"{key}")
print()
print("These missing keys may indicate:")
print(" • The model architecture has changed")
print(" • Some components were not properly saved in the original model")
print(" • The migration script needs to be updated for this policy type")
print()
print("What to do next:")
print(" 1. Test your migrated model carefully to ensure it works as expected")
print(" 2. If you encounter issues, please open an issue at:")
print(" https://github.com/huggingface/lerobot/issues")
print(" 3. Include this migration log and the missing keys listed above")
print()
print("If the model works correctly despite these warnings, the missing keys")
print("might be expected for your policy type and can be added to the whitelist.")
print("=" * 60)
def load_model_from_hub(
repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
"""
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -347,13 +442,12 @@ def load_model_from_hub(
Returns:
A tuple containing the model's state dictionary, the policy configuration,
and the training configuration.
and the training configuration (None if train_config.json is not found).
"""
# Download files.
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
@@ -362,8 +456,14 @@ def load_model_from_hub(
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
try:
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
with open(train_config_path) as f:
train_config = json.load(f)
except FileNotFoundError:
print("train_config.json not found - continuing without training configuration")
return state_dict, config, train_config
@@ -409,8 +509,15 @@ def main():
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
if os.path.exists(train_config_path):
with open(train_config_path) as f:
train_config = json.load(f)
else:
print("train_config.json not found - continuing without training configuration")
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
@@ -487,10 +594,20 @@ def main():
policy_class = get_policy_class(policy_type)
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
known_missing_keys_whitelist = {
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
# Add other policy types and their known missing keys here as needed
}
# Load state dict with graceful missing key handling
problematic_missing_keys = load_state_dict_with_missing_key_handling(
policy=policy,
state_dict=new_state_dict,
policy_type=policy_type,
known_missing_keys_whitelist=known_missing_keys_whitelist,
)
policy.to(torch.float32)
# Create preprocessor and postprocessor using the factory
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
@@ -520,7 +637,9 @@ def main():
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
dataset_repo_id = "unknown"
if train_config is not None:
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
@@ -641,6 +760,9 @@ final_action = postprocessor(action)
else:
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
# Display final summary about any problematic missing keys
display_migration_summary_with_warnings(problematic_missing_keys)
if __name__ == "__main__":
main()