diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 901a2672a..c56197dff 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -27,7 +27,6 @@ from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast import torch from huggingface_hub import ModelHubMixin, hf_hub_download -from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_file, save_file from lerobot.configs.types import PolicyFeature @@ -429,7 +428,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): (e.g., "username/processor-name"). config_filename: Optional specific config filename to load. If not provided, will: - For local paths: look for any .json file in the directory (error if multiple found) - - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") + - For HF Hub: REQUIRED - you must specify the exact config filename overrides: Optional dictionary mapping step names to configuration overrides. Keys must match exact step class names (for unregistered steps) or registry names (for registered steps). Values are dictionaries containing parameter overrides @@ -455,10 +454,10 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): processor = DataProcessorPipeline.from_pretrained("path/to/processor") ``` - Loading specific config file: + Loading from HF Hub (config_filename required): ```python processor = DataProcessorPipeline.from_pretrained( - "username/multi-processor-repo", config_filename="preprocessor.json" + "username/processor-repo", config_filename="processor.json" ) ``` @@ -486,7 +485,19 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): # Use the local variable name 'source' for clarity source = str(pretrained_model_name_or_path) - if Path(source).is_dir(): + # Check if it's a local path (either exists or looks like a filesystem path) + # Hub repositories are typically in the format "username/repo-name" (exactly one slash) + # Local paths are absolute paths, relative paths, or have more complex path structure + is_local_path = ( + Path(source).is_dir() + or Path(source).is_absolute() + or source.startswith("./") + or source.startswith("../") + or source.count("/") > 1 # More than one slash suggests local path, not Hub repo + or "\\" in source # Windows-style paths are definitely local + ) + + if is_local_path: # Local path - use it directly base_path = Path(source) @@ -505,57 +516,26 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): with open(base_path / config_filename) as file_pointer: loaded_config: dict[str, Any] = json.load(file_pointer) else: - # Hugging Face Hub - download all required files + # Hugging Face Hub - download specific config file if config_filename is None: - # Try common config names - common_names = [ - "robot_processor.json", - "robot_preprocessor.json", - "robot_postprocessor.json", - ] - config_path = None - for name in common_names: - try: - config_path = hf_hub_download( - source, - name, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - config_filename = name - break - except (FileNotFoundError, OSError, HfHubHTTPError): - # FileNotFoundError: local file issues - # OSError: network/system errors - # HfHubHTTPError: file not found on Hub (404) or other HTTP errors - continue - - if config_path is None: - raise FileNotFoundError( - f"No processor configuration file found in {source}. " - f"Tried: {common_names}. Please specify the config_filename parameter." - ) - else: - # Download specific config file - config_path = hf_hub_download( - source, - config_filename, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, + raise ValueError( + f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. " + f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')" ) + config_path = hf_hub_download( + source, + config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + with open(config_path) as file_pointer: loaded_config = json.load(file_pointer) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 2bc7991fc..d7ab7d6a0 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -1714,16 +1714,26 @@ def test_override_with_device_strings(): def test_from_pretrained_nonexistent_path(): """Test error handling when loading from non-existent sources.""" - from huggingface_hub.errors import HfHubHTTPError, HFValidationError + from huggingface_hub.errors import HfHubHTTPError - # Test with an invalid repo ID (too many slashes) - caught by HF validation - with pytest.raises(HFValidationError): + # Test with an invalid local path - should raise FileNotFoundError + with pytest.raises(FileNotFoundError): DataProcessorPipeline.from_pretrained("/path/that/does/not/exist") - # Test with a non-existent but valid Hub repo format - with pytest.raises((FileNotFoundError, HfHubHTTPError)): + # Test with a Hub repo format that would be a local path (too many slashes) + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained("user/repo/extra/path") + + # Test with a non-existent but valid Hub repo format (now requires config_filename) + with pytest.raises(ValueError, match="you must specify the config_filename parameter"): DataProcessorPipeline.from_pretrained("nonexistent-user/nonexistent-repo") + # Test with a non-existent Hub repo when config_filename is provided + with pytest.raises((FileNotFoundError, HfHubHTTPError)): + DataProcessorPipeline.from_pretrained( + "nonexistent-user/nonexistent-repo", config_filename="processor.json" + ) + # Test with a local directory that exists but has no config files with tempfile.TemporaryDirectory() as tmp_dir: with pytest.raises(FileNotFoundError, match="No .json configuration files found"):