mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
@@ -433,8 +433,53 @@ class RobotProcessor(ModelHubMixin):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, source: str) -> RobotProcessor:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier)."""
|
||||
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||
|
||||
Args:
|
||||
source: Local path to a saved processor directory or Hugging Face Hub identifier
|
||||
(e.g., "username/processor-name").
|
||||
overrides: Optional dictionary mapping step names to configuration overrides.
|
||||
Keys must match exact step class names (for unregistered steps) or registry names
|
||||
(for registered steps). Values are dictionaries containing parameter overrides
|
||||
that will be merged with the saved configuration. This is useful for providing
|
||||
non-serializable objects like environment instances.
|
||||
|
||||
Returns:
|
||||
A RobotProcessor instance loaded from the saved configuration.
|
||||
|
||||
Raises:
|
||||
ImportError: If a processor step class cannot be loaded or imported.
|
||||
ValueError: If a step cannot be instantiated with the provided configuration.
|
||||
KeyError: If an override key doesn't match any step in the saved configuration.
|
||||
|
||||
Examples:
|
||||
Basic loading:
|
||||
```python
|
||||
processor = RobotProcessor.from_pretrained("path/to/processor")
|
||||
```
|
||||
|
||||
Loading with overrides for non-serializable objects:
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("CartPole-v1")
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"username/cartpole-processor",
|
||||
overrides={"ActionRepeatStep": {"env": env}}
|
||||
)
|
||||
```
|
||||
|
||||
Multiple overrides:
|
||||
```python
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"path/to/processor",
|
||||
overrides={
|
||||
"CustomStep": {"param1": "new_value"},
|
||||
"device_processor": {"device": "cuda:1"} # For registered steps
|
||||
}
|
||||
)
|
||||
```
|
||||
"""
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
@@ -450,6 +495,13 @@ class RobotProcessor(ModelHubMixin):
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
|
||||
# Handle None overrides
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
||||
# Validate that all override keys will be matched
|
||||
override_keys = set(overrides.keys())
|
||||
|
||||
steps: list[ProcessorStep] = []
|
||||
for step_entry in config["steps"]:
|
||||
# Check if step uses registry name or module path
|
||||
@@ -457,6 +509,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
# Load from registry
|
||||
try:
|
||||
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
|
||||
step_key = step_entry["registry_name"]
|
||||
except KeyError as e:
|
||||
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
|
||||
else:
|
||||
@@ -468,6 +521,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
step_class = getattr(module, class_name)
|
||||
step_key = class_name
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ImportError(
|
||||
f"Failed to load processor step '{full_class_path}'. "
|
||||
@@ -478,7 +532,15 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
# Instantiate the step with its config
|
||||
try:
|
||||
step_instance: ProcessorStep = step_class(**step_entry.get("config", {}))
|
||||
saved_cfg = step_entry.get("config", {})
|
||||
step_overrides = overrides.get(step_key, {})
|
||||
merged_cfg = {**saved_cfg, **step_overrides}
|
||||
step_instance: ProcessorStep = step_class(**merged_cfg)
|
||||
|
||||
# Track which override keys were used
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
|
||||
except Exception as e:
|
||||
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
|
||||
raise ValueError(
|
||||
@@ -499,6 +561,23 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
steps.append(step_instance)
|
||||
|
||||
# Check for unused override keys
|
||||
if override_keys:
|
||||
available_keys = []
|
||||
for step_entry in config["steps"]:
|
||||
if "registry_name" in step_entry:
|
||||
available_keys.append(step_entry["registry_name"])
|
||||
else:
|
||||
full_class_path = step_entry["class"]
|
||||
class_name = full_class_path.rsplit(".", 1)[1]
|
||||
available_keys.append(class_name)
|
||||
|
||||
raise KeyError(
|
||||
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
|
||||
f"Available step keys: {available_keys}. "
|
||||
f"Make sure override keys match exact step class names or registry names."
|
||||
)
|
||||
|
||||
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user