feat (overrides): Implement support for loading processors with parameter overrides

- Added an `overrides` parameter for supplying non-serializable objects (e.g., environment instances) when loading processors from saved configurations.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
This commit is contained in:
Adil Zouitine
2025-07-07 12:01:34 +02:00
parent 1c56779dd9
commit 3b8a3a32a0
3 changed files with 882 additions and 5 deletions

View File

@@ -433,8 +433,53 @@ class RobotProcessor(ModelHubMixin):
return self
@classmethod
def from_pretrained(cls, source: str) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier)."""
def from_pretrained(cls, source: str, *, overrides: dict[str, Any] | None = None) -> RobotProcessor:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
source: Local path to a saved processor directory or Hugging Face Hub identifier
(e.g., "username/processor-name").
overrides: Optional dictionary mapping step names to configuration overrides.
Keys must match exact step class names (for unregistered steps) or registry names
(for registered steps). Values are dictionaries containing parameter overrides
that will be merged with the saved configuration. This is useful for providing
non-serializable objects like environment instances.
Returns:
A RobotProcessor instance loaded from the saved configuration.
Raises:
ImportError: If a processor step class cannot be loaded or imported.
ValueError: If a step cannot be instantiated with the provided configuration.
KeyError: If an override key doesn't match any step in the saved configuration.
Examples:
Basic loading:
```python
processor = RobotProcessor.from_pretrained("path/to/processor")
```
Loading with overrides for non-serializable objects:
```python
import gym
env = gym.make("CartPole-v1")
processor = RobotProcessor.from_pretrained(
"username/cartpole-processor",
overrides={"ActionRepeatStep": {"env": env}}
)
```
Multiple overrides:
```python
processor = RobotProcessor.from_pretrained(
"path/to/processor",
overrides={
"CustomStep": {"param1": "new_value"},
"device_processor": {"device": "cuda:1"} # For registered steps
}
)
```
"""
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
@@ -450,6 +495,13 @@ class RobotProcessor(ModelHubMixin):
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
# Handle None overrides
if overrides is None:
overrides = {}
# Collect the override keys; each one is discarded when a step matches it, so any left after the loop are reported as unmatched
override_keys = set(overrides.keys())
steps: list[ProcessorStep] = []
for step_entry in config["steps"]:
# Check if step uses registry name or module path
@@ -457,6 +509,7 @@ class RobotProcessor(ModelHubMixin):
# Load from registry
try:
step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
step_key = step_entry["registry_name"]
except KeyError as e:
raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
else:
@@ -468,6 +521,7 @@ class RobotProcessor(ModelHubMixin):
try:
module = importlib.import_module(module_path)
step_class = getattr(module, class_name)
step_key = class_name
except (ImportError, AttributeError) as e:
raise ImportError(
f"Failed to load processor step '{full_class_path}'. "
@@ -478,7 +532,15 @@ class RobotProcessor(ModelHubMixin):
# Instantiate the step with its config
try:
step_instance: ProcessorStep = step_class(**step_entry.get("config", {}))
saved_cfg = step_entry.get("config", {})
step_overrides = overrides.get(step_key, {})
merged_cfg = {**saved_cfg, **step_overrides}
step_instance: ProcessorStep = step_class(**merged_cfg)
# Track which override keys were used
if step_key in override_keys:
override_keys.discard(step_key)
except Exception as e:
step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
raise ValueError(
@@ -499,6 +561,23 @@ class RobotProcessor(ModelHubMixin):
steps.append(step_instance)
# Check for unused override keys
if override_keys:
available_keys = []
for step_entry in config["steps"]:
if "registry_name" in step_entry:
available_keys.append(step_entry["registry_name"])
else:
full_class_path = step_entry["class"]
class_name = full_class_path.rsplit(".", 1)[1]
available_keys.append(class_name)
raise KeyError(
f"Override keys {list(override_keys)} do not match any step in the saved configuration. "
f"Available step keys: {available_keys}. "
f"Make sure override keys match exact step class names or registry names."
)
return cls(steps, config.get("name", "RobotProcessor"), config.get("seed"))
def __len__(self) -> int: