From e5ade5565d4dc2b9b3f82fd822383a2a9cc11f06 Mon Sep 17 00:00:00 2001
From: Pepijn <138571049+pkooij@users.noreply.github.com>
Date: Thu, 7 Aug 2025 16:13:34 +0200
Subject: [PATCH] Integrate pipeline and add phone teleop (#1681)

* Add normalization processor and related components
- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Enhance processing architecture with new components
- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
* chore (docs): add docstring for processor
* fix (test): test factory
* fix(test): policies
* Update tests/processor/test_observation_processor.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine
* chore(test): add suggestion made by copilot regarding numpy test
* fix(test): import issue
* Refactor normalization components and update tests
- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.
* chore (docstring): Improve docstring for NormalizerProcessor
* feat (device processor): Implement device processor
* chore (batch handling): Enhance processing components with batch conversion utilities
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix(test): linting issue
* chore (output format): improves output format
* chore (type): add typing for multiprocess envs
* feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
* chore(normalization): addressing comments from copilot
* chore(learner): nit comment from copilot
* feat(pipeline): Enhance step_through method to support both tuple and dict inputs
* refactor(pipeline): Simplify observation and padding data handling in batch transitions
* Apply suggestions from code review
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions
* fix(ci): temporary fix on dataset deps version
* feat(processors): Introduce processors for various policy types
- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.
* refactor(learner): Remove normalization from cached image features retrieval
- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.
* refactor(policies): Remove unnormalization step from action predictions
- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.
* feat(train): Integrate preprocessor into training pipeline
* refactor(train): Update preprocessor initialization to include dataset statistics
* refactor(policies): Enhance processor creation and add NaN detection hook
* refactor(train): Update memory pinning logic for mps compatibility
* feat: initial commit phone teleop
* ugly delta control
* use quaternion
* Refactor observation preprocessing to use a modular pipeline system
- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Refactor observation processing and improve modularity
- Updated `ObservationProcessor` to enhance the modular design for processing observations.
- Cleaned up imports and improved code readability by removing unnecessary lines and comments.
- Ensured backward compatibility while integrating new processing components.
- Added tests to validate the functionality of the updated processing architecture.
* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.
* Refactor processing architecture to use RobotProcessor
- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
- Introduced ProcessorStepRegistry for better management of processing steps.
- Updated relevant documentation and tests to reflect the new processing structure.
- Enhanced the save/load functionality to support the new processor design.
- Added a model card template for RobotProcessor to facilitate sharing and documentation.
* Add RobotProcessor tutorial to documentation
- Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
- Added a section in the table of contents for easy navigation to the new tutorial.
- The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling
- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.
* feat(pipeline): Add hook unregistration functionality and enhance documentation
- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
* refactor(pipeline): Clarify hook behavior and improve documentation
- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability
- Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
- Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
- Ensured that the representation handles long lists of steps with truncation for better readability.
* chore(pipeline): Move _CFG_NAME alongside other class members
* refactor(pipeline): Utilize get_safe_torch_device for device assignment
- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
* refactor(pipeline): Enhance state filename generation and profiling method
- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
* chore(doc): address pip install command for lerobot that does not exist yet
* feat(pipeline): Enhance configuration filename handling and state file naming
- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
* docs(pipeline): Add clarification for repo name sanitization process
* feat(processors): Introduce processors for various policy types
- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.
* refactor(learner): Remove normalization from cached image features retrieval
- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.
* refactor(policies): Remove unnormalization step from action predictions
- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.
* feat(train): Integrate preprocessor into training pipeline
* refactor(train): Update preprocessor initialization to include dataset statistics
* refactor(policies): Enhance processor creation and add NaN detection hook
* feat(record): Integrate RobotProcessor into recording loop and update policy handling
- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
* feat(migration): Add script for migrating policy models with normalization layers
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models
- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
* feat(migrate): Add model card generation and saving to migration script
- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
* feat(processor): Introduce ToBatchProcessor for handling observation batching
- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
* feat(processors): Add ToBatchProcessor to multiple policy processors
- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.
* refactor(factory): Remove unused imports and NaN detection hook from processor creation
* feat(batch_processor): Enhance ToBatchProcessor to handle action batching
- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration
- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.
* refactor(factory): Clean up imports in factory.py
- Removed unused import of IdentityProcessor to streamline the code.
* feat(migrate): Extend load_model_from_hub to include train configuration
- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.
* refactor(record): Rename processor parameters and update processing logic
- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.
* feat(batch_processor): Add task field processing to ToBatchProcessor
- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
* feat(normalization): Implement IDENTITY mode for normalization and unnormalization
- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
* fix(rebase): remove residual normalization layer
* refactor(diffusion): remove normalization layer from input processing
* Add debug + calib
* cleanup
* Add pipeline
* fix int
* Add record example
* nit
* Add feature contract to pipelinestep and pipeline
* Add tests
* Add processor tests
* PR feedback
* incorporate PR feedback
* type in doc
* oops
* cleaned up steps and integrated pipeline with feature_contract
* refactor steps and robot to pipeline
* cleanup pipeline
* cleanup code further
* make it run
* refactor(normalization): Remove unused state dict transformation methods and streamline imports
- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
* refactor(normalization): Clean up imports in normalize_processor.py
* feat(batch_processor): Add feature_contract method to ToBatchProcessor
- Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor.
- This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs.
* fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* feat(tokenizer): Introduce TokenizerProcessor for text tokenization
- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.
* feat(language): Enhance language processing in TokenizerProcessor
- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.
* feat(tokenizer): Add padding configuration to TokenizerProcessor
- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.
* feat(processor): Add state management methods to Pi0NewLineProcessor
* feat(normalization): Track normalization and unnormalization info in complementary data
- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * Do some todos and cleanup * change feature_contract to dataset_features * use one method for conversion pipeline output to add_frame dict and use base processors where possible * Add back in and use record_loop * update todo * rename to_dataset_frame * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. 
* feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix reference frame * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. 
* update data visualization * update teleop example * fix record bugs * Add replay * Not code * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. 
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feature(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. 
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. 
* Add eval script * fix `q_curr` in InverseKinematicsEEToJoints to the IK solution * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. 
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. 
* feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. 
* refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. 
* feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * Fix eval and android gripper * add some tests * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. 
* fix(train.py) push postprocessor with preprocessor - Add preprocessor policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * Cleanup pr * fix more git diff pr issues * add path as type in save_pretrained * small nit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename test file * fix: make dataset_features/feature_contract optional * fix tests * Incorporate PR feedback * clean up record.py * add ascii art, fix normal record * remove merge issues * fix merge * remove features * Add feedback PR * fix last 4 tests * remove features check * rename to transform_features * add transform_features * fix lekiwi eval and update eval api example --------- Signed-off-by: Adil Zouitine Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Michel Aractingi --- docs/source/il_robots.mdx | 15 +- examples/lekiwi/evaluate.py | 15 +- examples/lekiwi/teleoperate.py | 2 +- examples/phone_so100_eval.py | 158 +++++++ examples/phone_so100_record.py | 215 +++++++++ examples/phone_so100_replay.py | 106 +++++ examples/phone_so100_teleop.py | 109 +++++ pyproject.toml | 4 +- src/lerobot/datasets/pipeline_features.py | 94 ++++ src/lerobot/datasets/utils.py | 44 ++ src/lerobot/policies/pi0/processor_pi0.py | 4 +- .../policies/smolvla/processor_smolvla.py | 4 +- src/lerobot/processor/batch_processor.py | 4 +- src/lerobot/processor/converters.py | 225 +++++++++ src/lerobot/processor/device_processor.py | 2 +- src/lerobot/processor/normalize_processor.py | 4 +- .../processor/observation_processor.py | 3 +- src/lerobot/processor/pipeline.py | 45 +- src/lerobot/processor/rename_processor.py | 2 +- 
src/lerobot/processor/tokenizer_processor.py | 2 +- src/lerobot/record.py | 145 ++++-- src/lerobot/robots/so100_follower/__init__.py | 3 +- .../so100_follower/config_so100_follower.py | 32 -- .../robot_kinematic_processor.py | 447 ++++++++++++++++++ .../so100_follower_end_effector.py | 200 -------- src/lerobot/robots/utils.py | 1 + src/lerobot/teleoperate.py | 2 +- src/lerobot/teleoperators/phone/__init__.py | 18 + .../teleoperators/phone/config_phone.py | 36 ++ src/lerobot/teleoperators/phone/phone.py | 246 ++++++++++ .../teleoperators/phone/phone_processor.py | 87 ++++ src/lerobot/utils/visualization_utils.py | 101 +++- tests/datasets/test_dataset_utils.py | 132 ++++++ tests/datasets/test_utils.py | 55 --- tests/processor/test_converters.py | 196 ++++++++ tests/processor/test_device_processor.py | 6 +- tests/processor/test_normalize_processor.py | 17 +- tests/processor/test_observation_processor.py | 20 +- tests/processor/test_pipeline.py | 244 ++++++++-- tests/processor/test_rename_processor.py | 12 +- tests/processor/test_tokenizer_processor.py | 12 +- tests/utils/test_visualization_utils.py | 205 ++++++++ 42 files changed, 2819 insertions(+), 455 deletions(-) create mode 100644 examples/phone_so100_eval.py create mode 100644 examples/phone_so100_record.py create mode 100644 examples/phone_so100_replay.py create mode 100644 examples/phone_so100_teleop.py create mode 100644 src/lerobot/datasets/pipeline_features.py create mode 100644 src/lerobot/processor/converters.py create mode 100644 src/lerobot/robots/so100_follower/robot_kinematic_processor.py delete mode 100644 src/lerobot/robots/so100_follower/so100_follower_end_effector.py create mode 100644 src/lerobot/teleoperators/phone/__init__.py create mode 100644 src/lerobot/teleoperators/phone/config_phone.py create mode 100644 src/lerobot/teleoperators/phone/phone.py create mode 100644 src/lerobot/teleoperators/phone/phone_processor.py create mode 100644 tests/datasets/test_dataset_utils.py delete mode 100644 
tests/datasets/test_utils.py create mode 100644 tests/processor/test_converters.py create mode 100644 tests/utils/test_visualization_utils.py diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index ec5491b2a..f1c15a1d0 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun from lerobot.record import record_loop +from lerobot.policies.factory import make_processor NUM_EPISODES = 5 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot configuration camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} @@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig( robot = SO100Follower(robot_config) # Initialize the policy -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/eval_", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -559,6 +562,12 @@ _init_rerun(session_name="recording") # Connect the robot robot.connect() +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") @@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES): events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git 
a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 57fb62e10..564648329 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -1,6 +1,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_processor from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.control_utils import init_keyboard_listener @@ -11,12 +12,14 @@ NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot and teleoperator configurations robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") robot = LeKiwiClient(robot_config) -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -43,6 +46,12 @@ listener, events = init_keyboard_listener() if not robot.is_connected: raise ValueError("Robot is not connected!") +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") @@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, 
control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index 8358a2b93..45afca0cf 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -38,7 +38,7 @@ while True: keyboard_keys = keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_keys) - log_rerun_data(observation, {**arm_action, **base_action}) + log_rerun_data(observation=observation, action={**arm_action, **base_action}) action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action diff --git a/examples/phone_so100_eval.py b/examples/phone_so100_eval.py new file mode 100644 index 000000000..e3a577de5 --- /dev/null +++ b/examples/phone_so100_eval.py @@ -0,0 +1,158 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.datasets.utils import merge_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_processor +from lerobot.processor.converters import ( + to_output_robot_action, + to_transition_robot_observation, +) +from lerobot.processor.pipeline import RobotProcessor +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Initialize the robot with degrees +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# 
Build pipeline to convert ee pose action to joint action +robot_ee_to_joints = RobotProcessor( + steps=[ + AddRobotObservationAsComplimentaryData(robot=robot), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, +) + +# Build pipeline to convert joint observation to ee pose observation +robot_joints_to_ee_pose = RobotProcessor( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=to_transition_robot_observation, + to_output=lambda tr: tr, +) + +# Build dataset action and gripper features +action_ee_and_gripper = aggregate_pipeline_dataset_features( + pipeline=robot_ee_to_joints, + initial_features={}, + use_videos=True, + patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"], +) # Get all ee action features + gripper pos action features + +# Build dataset observation features +obs_ee = aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=robot.observation_features, + use_videos=True, + patterns=["observation.state.ee"], +) # Get all ee observation features + +dataset_features = merge_features(obs_ee, action_ee_and_gripper) + +print("All dataset features: ", dataset_features) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +# Connect the robot and teleoperator +robot.connect() + +episode_idx = 0 + +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + 
dataset_stats=dataset.meta.stats, +) + +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + dataset.save_episode() + +# Clean up +log_say("Stop recording") +robot.disconnect() +dataset.push_to_hub() diff --git a/examples/phone_so100_record.py b/examples/phone_so100_record.py new file mode 100644 index 000000000..4ec3948ea --- /dev/null +++ b/examples/phone_so100_record.py @@ -0,0 +1,215 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.datasets.utils import merge_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.converters import ( + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) +from lerobot.processor.pipeline import RobotProcessor +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEE, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone import Phone +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 10 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Initialize the robot and teleoperator +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +phone = Phone(teleop_config) + +# NOTE: It is highly recommended to 
use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to ee pose action +phone_to_robot_ee_pose = RobotProcessor( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + AddRobotObservationAsComplimentaryData(robot=robot), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.20, + max_ee_twist_step_rad=0.50, + ), + ], + to_transition=to_transition_teleop_action, + to_output=lambda tr: tr, +) + +# Build pipeline to convert ee pose action to joint action +robot_ee_to_joints = RobotProcessor( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + GripperVelocityToJoint( + motor_names=list(robot.bus.motors.keys()), + speed_factor=20.0, + ), + ], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, +) + +# Build pipeline to convert joint observation to ee pose observation +robot_joints_to_ee_pose = RobotProcessor( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=to_transition_robot_observation, + to_output=lambda tr: tr, +) + +# Build dataset ee action features +action_ee = aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose, + initial_features=phone.action_features, + use_videos=True, + patterns=["action.ee"], +) + +# Get gripper pos action features +gripper = aggregate_pipeline_dataset_features( + 
pipeline=robot_ee_to_joints, + initial_features={}, + use_videos=True, + patterns=["action.gripper.pos", "observation.state.gripper.pos"], +) + +# Build dataset ee observation features +observation_ee = aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=robot.observation_features, + use_videos=True, + patterns=["observation.state.ee"], +) + +dataset_features = merge_features(action_ee, gripper, observation_ee) + +print("All dataset features: ", dataset_features) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +# Connect the robot and teleoperator +robot.connect() +phone.connect() + +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose, + robot_action_processor=robot_ee_to_joints, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + 
events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +phone.disconnect() +dataset.push_to_hub() diff --git a/examples/phone_so100_replay.py b/examples/phone_so100_replay.py new file mode 100644 index 000000000..f44207789 --- /dev/null +++ b/examples/phone_so100_replay.py @@ -0,0 +1,106 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import time + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.converters import to_output_robot_action +from lerobot.processor.pipeline import RobotProcessor +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.utils import log_say + +EPISODE_IDX = 0 +HF_REPO_ID = "/" + +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True +) +robot = SO100Follower(robot_config) +robot.connect() + +dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) +actions = dataset.hf_dataset.select_columns("action") + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + + +# This method converts the action from the dataset to a transition for pipeline +def action_to_transition(action: dict): + act = {} + + # EE pose + for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"): + if k in action: + act[f"action.{k}"] = float(action[k]) + + # Gripper: your dataset has absolute position + if "gripper.pos" in action: + act["action.gripper.pos"] = float(action["gripper.pos"]) + + return { + "observation": None, + "action": act, + "reward": None, + "done": False, + "truncated": False, + "info": {}, + "complementary_data": {}, + } + + +# Build pipeline to convert ee pose action to joint action 
+robot_ee_to_joints = RobotProcessor( + steps=[ + AddRobotObservationAsComplimentaryData(robot=robot), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=action_to_transition, + to_output=to_output_robot_action, +) + +robot_ee_to_joints.reset() + +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(dataset.num_frames): + t0 = time.perf_counter() + + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + joint_action = robot_ee_to_joints(ee_action) + action_sent = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +robot.disconnect() diff --git a/examples/phone_so100_teleop.py b/examples/phone_so100_teleop.py new file mode 100644 index 000000000..82515c98f --- /dev/null +++ b/examples/phone_so100_teleop.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specif + +import time + +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotProcessor +from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + AddRobotObservationAsComplimentaryData, + EEBoundsAndSafety, + EEReferenceAndDelta, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone import Phone +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction + +# Initialize the robot and teleoperator +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +teleop_device = Phone(teleop_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to ee pose action +phone_to_robot_ee_pose = RobotProcessor( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + AddRobotObservationAsComplimentaryData(robot=robot), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + ), + EEBoundsAndSafety( + 
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + ], + to_transition=to_transition_teleop_action, + to_output=lambda tr: tr, +) + +# Build pipeline to convert ee pose action to joint action +robot_ee_to_joints = RobotProcessor( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + ), + GripperVelocityToJoint( + motor_names=list(robot.bus.motors.keys()), + speed_factor=20.0, + ), + ], + to_transition=lambda tr: tr, + to_output=to_output_robot_action, +) + +robot.connect() +teleop_device.connect() + +print("Starting teleop loop. Move your phone to teleoperate the robot.") +while True: + phone_obs = teleop_device.get_action() + if not phone_obs: + time.sleep(0.01) + continue + + # Get teleop observation + phone_obs = teleop_device.get_action() + + # Phone to EE pose transition + ee_transition = phone_to_robot_ee_pose(phone_obs) + + # EE pose to Joints transition + joint_action = robot_ee_to_joints(ee_transition) + + if joint_action: + robot.send_action(joint_action) + + time.sleep(0.01) diff --git a/pyproject.toml b/pyproject.toml index 968005281..bdd634f71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] +phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # stretch = [ # "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", # "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", @@ -152,7 +153,8 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]" + "lerobot[xarm]", + "lerobot[phone]", ] [project.scripts] diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py new file mode 100644 index 000000000..fef75b407 --- /dev/null +++ 
b/src/lerobot/datasets/pipeline_features.py @@ -0,0 +1,94 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import Any + +from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.processor.pipeline import RobotProcessor + + +def aggregate_pipeline_dataset_features( + pipeline: RobotProcessor, + initial_features: dict[str, Any], + *, + use_videos: bool = True, + patterns: Sequence[str] | None = None, +) -> dict[str, dict]: + """ + Aggregates the pipeline's features and returns a features dict ready for the dataset, + filtered to only those keys matching any of the given patterns (for action/state only). + + - `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...} + - `use_videos`: whether to treat image features as video streams + - `patterns`: regexes to filter action & state features; images are included + whenever use_videos=True, regardless of patterns. 
+ """ + import re + + # Gather every feature the pipeline specifies, seeded with the hardware cameras: + all_features = pipeline.transform_features(initial_features) + + # Helper to decide which action/state keys survive the `patterns` filter: + def keep(key: str) -> bool: + if patterns is None: + return True + return any(re.search(pat, key) for pat in patterns) + + # Start with the hardware dict, injecting initial cameras if videos are ON: + hw: dict[str, dict[str, Any]] = {} + if use_videos: + cams = { + name: shape + for name, shape in initial_features.items() + if isinstance(shape, tuple) and len(shape) == 3 + } + if cams: + hw["observation"] = dict(cams) + + # Go over every feature from the pipeline and merge: + for full_key, ty in all_features.items(): + if full_key.startswith("action."): + # action. + if not keep(full_key): + continue + name = full_key[len("action.") :] + hw.setdefault("action", {})[name] = ty + + elif full_key.startswith("observation.state."): + # observation.state. + if not keep(full_key): + continue + name = full_key[len("observation.state.") :] + hw.setdefault("observation", {})[name] = ty + + elif full_key.startswith("observation.images."): + # observation.images. + # images obey ONLY the use_videos flag, not patterns + if not use_videos: + continue + name = full_key[len("observation.images.") :] + hw.setdefault("observation", {})[name] = ty + + else: + # anything else (e.g. 
policy-only features) is ignored here + continue + + out: dict[str, dict] = {} + if "action" in hw: + out.update(hw_to_dataset_features(hw["action"], "action", use_videos)) + if "observation" in hw: + out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos)) + + return out diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 078c5351d..db60e63b3 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea return policy_features +def merge_features(*dicts: dict) -> dict: + """ + Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": merge the names and recompute the shape. + - For all other entries (e.g. observation.images.*): the last definition wins (duplicates are expected to be identical). + """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + 
return out + + def create_empty_dataset_info( codebase_version: str, fps: int, diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 06cb9848a..4c411dd66 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -65,8 +65,8 @@ class Pi0NewLineProcessor(ProcessorStep): return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Add tokenized task features to the feature contract.""" + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Adds nothing to the features.""" return features def state_dict(self) -> dict[str, torch.Tensor]: diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 5a8caec60..2c0221f9e 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -88,8 +88,8 @@ class SmolVLANewLineProcessor(ProcessorStep): return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Add tokenized task features to the feature contract.""" + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Adds nothing to the features.""" return features def state_dict(self) -> dict[str, torch.Tensor]: diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 40017760b..8a74afd3e 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -17,6 +17,7 @@ from typing import Any import torch from torch import Tensor +from lerobot.configs.types import PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @@ -134,6 +135,5 @@ class 
ToBatchProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, Any]) -> dict[str, Any]: - """Return features (no-op for this processor).""" + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py new file mode 100644 index 000000000..f0e081577 --- /dev/null +++ b/src/lerobot/processor/converters.py @@ -0,0 +1,225 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from copy import deepcopy +from typing import Any + +import numpy as np +import torch +from scipy.spatial.transform import Rotation + +from .pipeline import EnvTransition, TransitionKey + + +def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]): + if isinstance(x, torch.Tensor): + return x + if isinstance(x, np.ndarray): + # Keep images (uint8 HWC) and python objects as-is + if x.dtype == np.uint8 or x.dtype == np.object_: + return x + # Scalars/arrays to float32 tensor + return torch.as_tensor(x, dtype=torch.float32) + # Anything else to float32 tensor + return torch.as_tensor(x, dtype=torch.float32) + + +def _from_tensor(x: Any): + if isinstance(x, torch.Tensor): + return x.item() if x.numel() == 1 else x.detach().cpu().numpy() + return x + + +def _is_image(arr: Any) -> bool: + return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3 + + +def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + state, images = {}, {} + for k, v in obs.items(): + if _is_image(v): + images[k] = v + else: + state[k] = v + return state, images + + +def make_obs_act_transition( + *, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None +) -> EnvTransition: + return { + TransitionKey.OBSERVATION: {} if obs is None else obs, + TransitionKey.ACTION: {} if act is None else act, + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {}, + TransitionKey.REWARD: None, + TransitionKey.DONE: None, + TransitionKey.TRUNCATED: None, + } + + +def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition: + """ + Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey. + """ + act_dict: dict[str, Any] = {} + for k, v in action.items(): + # Check if the value is a type that should not be converted to a tensor. 
+ if isinstance(v, (Rotation, dict)): + act_dict[f"action.{k}"] = v + continue + + arr = np.array(v) if np.isscalar(v) else v + act_dict[f"action.{k}"] = _to_tensor(arr) + + return make_obs_act_transition(act=act_dict) + + +# TODO(Adil, Pepijn): Over time, we may move these converters into pipeline.py itself +def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition: + """ + Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey. + """ + state, images = _split_obs_to_state_and_images(observation) + + obs_dict: dict[str, Any] = {} + for k, v in state.items(): + arr = np.array(v) if np.isscalar(v) else v + obs_dict[f"observation.state.{k}"] = _to_tensor(arr) + + for cam, img in images.items(): + obs_dict[f"observation.images.{cam}"] = img + + return make_obs_act_transition(obs=obs_dict) + + +def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: + """ + Converts an EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' or '.vel' for raw robot actions. + """ + out: dict[str, Any] = {} + action_dict = transition.get(TransitionKey.ACTION) or {} + + for k, v in action_dict.items(): + if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")): + out_key = k[len("action.") :] # Strip the 'action.' prefix. + out[out_key] = float(v) + + return out + + +def to_dataset_frame( + transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict] +) -> dict[str, Any]: + """ + Converts a single EnvTransition or an iterable of them into a flat, + dataset-friendly dictionary for training or evaluation, according to + the provided `features` spec. + + Args: + transitions_or_transition: Either a single EnvTransition dict + or an iterable of them (which will be merged). 
+ features (dict[str, dict]): + A feature specification dictionary: + - 'action': dict with 'names': list of action feature names + - 'observation.state': dict with 'names': list of state feature names + - keys starting with 'observation.images.' are passed through + + Returns: + batch (dict[str, any]): Flat dictionary containing: + - numpy arrays for "observation.state" and "action" + - any image tensors defined in features + - next.{reward,done,truncated} + - info dict + - *_is_pad flags and task from complementary_data + """ + action_names = features.get("action", {}).get("names", []) + obs_state_names = features.get("observation.state", {}).get("names", []) + image_keys = [k for k in features if k.startswith("observation.images.")] + + def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition: + out = deepcopy(base) + for key in ( + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ): + if other.get(key): + out.setdefault(key, {}).update(deepcopy(other[key])) + for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED): + if k in other: + out[k] = other[k] + return out + + def _ensure_transition(obj) -> EnvTransition: + # single transition + if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj): + return obj + # iterable of transitions + if isinstance(obj, Iterable): + items = list(obj) + if not items: + return {} + acc = items[0] + for t in items[1:]: + acc = _merge(acc, t) + return acc + raise TypeError("Expected EnvTransition or iterable of them") + + tr = _ensure_transition(transitions_or_transition) + obs = tr.get(TransitionKey.OBSERVATION, {}) or {} + act = tr.get(TransitionKey.ACTION, {}) or {} + batch: dict[str, any] = {} + + # Images passthrough + for k in image_keys: + if k in obs: + batch[k] = obs[k] + + # Observation.state vector + if obs_state_names: + vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in 
obs_state_names] + batch["observation.state"] = np.asarray(vals, dtype=np.float32) + + # Action vector + if action_names: + vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names] + batch["action"] = np.asarray(vals, dtype=np.float32) + + # Next.* fields + if tr.get(TransitionKey.REWARD) is not None: + batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD]) + if tr.get(TransitionKey.DONE) is not None: + batch["next.done"] = _from_tensor(tr[TransitionKey.DONE]) + if tr.get(TransitionKey.TRUNCATED) is not None: + batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED]) + + # Complementary data flags and task + comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} + if comp: + # pad flags + for k, v in comp.items(): + if k.endswith("_is_pad"): + batch[k] = v + # task label + if comp.get("task") is not None: + batch["task"] = comp["task"] + + return batch diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 12f9a5abc..39bd1cf11 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -141,5 +141,5 @@ class DeviceProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 94390b004..92e654472 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -257,7 +257,7 @@ class NormalizerProcessor: def reset(self): pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -435,7 +435,7 @@ class 
UnnormalizerProcessor: def reset(self): pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 7d63db238..40273548e 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -106,9 +106,8 @@ class VanillaObservationProcessor(ObservationProcessor): def observation(self, observation): return self._process_observation(observation) - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Transforms feature keys to a standardized contract. - This method handles several renaming patterns: - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 6d3546035..19dc668f7 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -23,7 +23,7 @@ from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Protocol, TypedDict +from typing import Any, Protocol, TypedDict, runtime_checkable import torch from huggingface_hub import ModelHubMixin, hf_hub_download @@ -132,6 +132,7 @@ class ProcessorStepRegistry: cls._registry.clear() +@runtime_checkable class ProcessorStep(Protocol): """Structural typing interface for a single processor step. 
@@ -145,7 +146,6 @@ class ProcessorStep(Protocol): **Required**: - ``__call__(transition: EnvTransition) -> EnvTransition`` - - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` Optional helper protocol: * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable @@ -158,6 +158,8 @@ class ProcessorStep(Protocol): * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict containing torch tensors only. * ``reset()`` – Clear internal buffers at episode boundaries. + * ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` + If present, this method will be called to aggregate the dataset features of all steps. Example separation: - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} @@ -174,7 +176,7 @@ class ProcessorStep(Protocol): def reset(self) -> None: ... - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 @@ -354,7 +356,10 @@ class RobotProcessor(ModelHubMixin): hook(idx, current_transition) # Convert back to original format if needed - return self.to_output(current_transition) if called_with_batch else current_transition + if called_with_batch or self.to_output is not _default_transition_to_batch: + return self.to_output(current_transition) + else: + return current_transition def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: """Prepare and validate transition data for processing. 
@@ -819,23 +824,15 @@ class RobotProcessor(ModelHubMixin): f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" ) - fc = getattr(step, "feature_contract", None) - if not callable(fc): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" - ) - - def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """ - Apply ALL steps in order. Each step must implement - feature_contract(features) and return a dict (full or incremental schema). + Apply ALL steps in order. Only if a step has a features method, it will be called. + We aggregate the dataset features of all steps. """ features: dict[str, PolicyFeature] = deepcopy(initial_features) for _, step in enumerate(self.steps): - out = step.feature_contract(features) - if not isinstance(out, dict): - raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + out = step.transform_features(features) features = out return features @@ -895,7 +892,7 @@ class ObservationProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -955,7 +952,7 @@ class ActionProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1014,7 +1011,7 @@ class RewardProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features 
@@ -1078,7 +1075,7 @@ class DoneProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1138,7 +1135,7 @@ class TruncatedProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1203,7 +1200,7 @@ class InfoProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1249,7 +1246,7 @@ class ComplementaryDataProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1271,5 +1268,5 @@ class IdentityProcessor: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 4fe4105a5..db20424df 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor): def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Transforms: - Each 
key in the observation that appears in `rename_map` is renamed to its value. - Keys not in `rename_map` remain unchanged. diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index c7086d6ce..4ec9fb351 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -187,7 +187,7 @@ class TokenizerProcessor: """Reset processor state (no-op for this processor).""" pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Add tokenized task features to the feature contract. Args: diff --git a/src/lerobot/record.py b/src/lerobot/record.py index e73c76384..78b671646 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -72,12 +72,19 @@ from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import RobotProcessor +from lerobot.processor.converters import ( + to_dataset_frame, + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) from lerobot.processor.normalize_processor import rename_stats +from lerobot.processor.pipeline import IdentityProcessor, TransitionKey from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -191,6 +198,36 @@ class RecordConfig: return ["policy"] +""" --------------- record_loop() data flow -------------------------- + [ Robot ] + V + [ 
robot.get_observation() ] ---> raw_obs + V + [ robot_observation_processor ] ---> obs_transition + V + .-----( ACTION LOGIC )------------------. + V V + [ From Teleoperator ] [ From Policy ] + | | + | [teleop.get_action] -> raw_action | [predict_action] + | | | | + | V | V + | [teleop_action_processor] | | + | | | | + '---> teleop_transition '---> policy_transition + | | + '-------------------------.-------------' + V + [ robot_action_processor ] --> robot_action_to_send + V + [ robot.send_action() ] -- (Robot Executes) + V + ( Transitions are merged & added to Dataset ) + V + ( Rerun Log / Loop Wait ) +""" + + @safe_stop_image_writer def record_loop( robot: Robot, @@ -202,14 +239,27 @@ def record_loop( preprocessor: RobotProcessor | None = None, postprocessor: RobotProcessor | None = None, control_time_s: int | None = None, + teleop_action_processor: RobotProcessor | None = None, # runs after teleop + robot_action_processor: RobotProcessor | None = None, # runs before robot + robot_observation_processor: RobotProcessor | None = None, # runs after robot single_task: str | None = None, display_data: bool = False, ): + teleop_action_processor = teleop_action_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr + ) + robot_action_processor = robot_action_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action + ) + robot_observation_processor = robot_observation_processor or RobotProcessor( + steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr + ) + if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") teleop_arm = teleop_keyboard = None - if isinstance(teleop, list): + if isinstance(teleop, list): # For LeKiwi teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None) 
teleop_arm = next( ( @@ -226,11 +276,20 @@ def record_loop( ) # Reset policy and processor if they are provided - if policy is not None or preprocessor is not None: + if policy is not None and preprocessor is not None and postprocessor is not None: policy.reset() preprocessor.reset() postprocessor.reset() + # Reset custom pipelines + teleop_action_processor.reset() + robot_action_processor.reset() + robot_observation_processor.reset() + + policy_transition = None + teleop_transition = None + obs_transition = None + timestamp = 0 start_episode_t = time.perf_counter() while timestamp < control_time_s: @@ -240,12 +299,19 @@ def record_loop( events["exit_early"] = False break - observation = robot.get_observation() + # Get robot observation + obs = robot.get_observation() - if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + # Applies a pipeline to the raw robot observation, default is IdentityProcessor + obs_transition = robot_observation_processor(obs) + + # Get action from either policy or teleop + if policy is not None and preprocessor is not None and postprocessor is not None: + if dataset is not None: + observation_frame = to_dataset_frame( + obs_transition, dataset.features + ) # Convert the observation to the dataset format - if policy is not None or preprocessor is not None: action_values = predict_action( observation=observation_frame, policy=policy, @@ -256,37 +322,64 @@ def record_loop( task=single_task, robot_type=robot.robot_type, ) - action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} - elif policy is None and isinstance(teleop, Teleoperator): - action = teleop.get_action() - elif policy is None and isinstance(teleop, list): - # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline) + + action_names = dataset.features["action"]["names"] + policy_action = {f"action.{name}": 
float(action_values[i]) for i, name in enumerate(action_names)} + policy_transition = { + TransitionKey.ACTION: policy_action, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + elif isinstance(teleop, Teleoperator): + act = teleop.get_action() + + # Applies a pipeline to the raw teleop action, default is IdentityProcessor + teleop_transition = teleop_action_processor(act) + + elif isinstance(teleop, list): arm_action = teleop_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - keyboard_action = teleop_keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_action) - - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + teleop_transition = teleop_action_processor(act) else: logging.info( - "No policy or teleoperator provided, skipping action generation." - "This is likely to happen when resetting the environment without a teleop device." - "The robot won't be at its rest position at the start of the next episode." + "No policy or teleoperator provided, skipping action generation. " + "This is likely to happen during environment reset." ) - continue + # Still continue to next loop to respect timing + # Applies a pipeline to the action, default is IdentityProcessor + # IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action() + if policy_transition is not None: + robot_action_to_send = robot_action_processor(policy_transition) + else: + robot_action_to_send = robot_action_processor(teleop_transition) + + # Send action to robot # Action can eventually be clipped using `max_relative_target`, # so action actually sent is saved in the dataset. action = postprocessor.process(action) - sent_action = robot.send_action(action) + # TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot. 
+ _ = robot.send_action(robot_action_to_send) + # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") - frame = {**observation_frame, **action_frame} + # If to_dataset_frame is provided, use it to merge the transitions. + merged = [] + if obs_transition is not None: # The observation from the robot + merged.append(obs_transition) + if teleop_transition is not None: # The action from teleop + merged.append(teleop_transition) + if policy_transition is not None: # The action from policy + merged.append(policy_transition) + frame = to_dataset_frame( + merged if len(merged) > 1 else merged[0], dataset.features + ) # Convert the observation to the dataset format dataset.add_frame(frame, task=single_task) if display_data: - log_rerun_data(observation, action) + log_rerun_data([obs_transition, teleop_transition or policy_transition]) dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) @@ -417,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset: return dataset -def main(): - record() - - if __name__ == "__main__": - main() + record() diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py index b995aab13..5dc43ac3b 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/robots/so100_follower/__init__.py @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig +from .config_so100_follower import SO100FollowerConfig from .so100_follower import SO100Follower -from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index ea8b9f1c2..16bab13e4 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False - - -@RobotConfig.register_subclass("so100_follower_end_effector") -@dataclass -class SO100FollowerEndEffectorConfig(SO100FollowerConfig): - """Configuration for the SO100FollowerEndEffector robot.""" - - # Path to URDF file for kinematics - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: - # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf - urdf_path: str | None = None - - # End-effector frame name in the URDF - target_frame_name: str = "gripper_frame_link" - - # Default bounds for the end-effector position (in meters) - end_effector_bounds: dict[str, list[float]] = field( - default_factory=lambda: { - "min": [-1.0, -1.0, -1.0], # min x, y, z - "max": [1.0, 1.0, 1.0], # max x, y, z - } - ) - - max_gripper_pos: float = 50 - - end_effector_step_sizes: dict[str, float] = field( - default_factory=lambda: { - "x": 0.02, - "y": 0.02, - "z": 0.02, - } - ) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py new file mode 100644 index 000000000..ed498557f --- /dev/null +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -0,0 +1,447 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +import numpy as np +from scipy.spatial.transform import Rotation + +from lerobot.configs.types import PolicyFeature +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.pipeline import ( + ActionProcessor, + ComplementaryDataProcessor, + EnvTransition, + ObservationProcessor, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.robots.robot import Robot + + +@ProcessorStepRegistry.register("ee_reference_and_delta") +@dataclass +class EEReferenceAndDelta: + """ + Compute the desired end-effector pose from the target pose and the current pose. 
+ + Input ACTION keys: + { + "action.enabled" : bool + "action.target_{x,y,z,wx,wy,wz}" : float + "complementary_data.raw_joint_positions": dict, + } + + Output ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + kinematics: RobotKinematics + end_effector_step_sizes: dict + motor_names: list[str] + + reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False) + _prev_enabled: bool = field(default=False, init=False, repr=False) + _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + # Get joint positions from complementary data + raw = comp.get("raw_joint_positions", None) + if raw is None: + raise ValueError( + "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta" + ) + + q = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + + # Current pose from FK on measured joints + t_curr = self.kinematics.forward_kinematics(q) + + enabled = bool(act.pop("action.enabled", 0)) + tx = float(act.pop("action.target_x", 0.0)) + ty = float(act.pop("action.target_y", 0.0)) + tz = float(act.pop("action.target_z", 0.0)) + wx = float(act.pop("action.target_wx", 0.0)) + wy = float(act.pop("action.target_wy", 0.0)) + wz = float(act.pop("action.target_wz", 0.0)) + + desired = None + + if enabled: + # Latch a reference at the rising edge; also be defensive if None + if not self._prev_enabled or self.reference_ee_pose is None: + self.reference_ee_pose = t_curr.copy() + + ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr + + delta_p = np.array( + [ + tx * self.end_effector_step_sizes["x"], + ty * self.end_effector_step_sizes["y"], + tz * self.end_effector_step_sizes["z"], + ], + dtype=float, + ) + r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + + desired = np.eye(4, 
dtype=float) + desired[:3, :3] = ref[:3, :3] @ r_abs + desired[:3, 3] = ref[:3, 3] + delta_p + + self._command_when_disabled = desired.copy() + else: + # While disabled, keep sending the same command to avoid drift. + if self._command_when_disabled is None: + # If we've never had an enabled command yet, freeze current FK pose once. + self._command_when_disabled = t_curr.copy() + desired = self._command_when_disabled.copy() + + # Write action fields + pos = desired[:3, 3] + tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() + act.update( + { + "action.ee.x": float(pos[0]), + "action.ee.y": float(pos[1]), + "action.ee.z": float(pos[2]), + "action.ee.wx": float(tw[0]), + "action.ee.wy": float(tw[1]), + "action.ee.wz": float(tw[2]), + } + ) + + self._prev_enabled = enabled + transition[TransitionKey.ACTION] = act + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@ProcessorStepRegistry.register("ee_bounds_and_safety") +@dataclass +class EEBoundsAndSafety(ActionProcessor): + """ + Clip the end-effector pose to the bounds and check for jumps. 
+ + Input ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + + Output ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + end_effector_bounds: dict + max_ee_step_m: float = 0.05 + max_ee_twist_step_rad: float = 0.20 + _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, act: dict | None) -> dict: + x = act.pop("action.ee.x", None) + y = act.pop("action.ee.y", None) + z = act.pop("action.ee.z", None) + wx = act.pop("action.ee.wx", None) + wy = act.pop("action.ee.wy", None) + wz = act.pop("action.ee.wz", None) + + if None in (x, y, z, wx, wy, wz): + return act + + pos = np.array([x, y, z], dtype=float) + twist = np.array([wx, wy, wz], dtype=float) + + # Clip position + pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"]) + + # Check for jumps in position + if self._last_pos is not None: + dpos = pos - self._last_pos + n = float(np.linalg.norm(dpos)) + if n > self.max_ee_step_m and n > 0: + pos = self._last_pos + dpos * (self.max_ee_step_m / n) + raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m") + + self._last_pos = pos + self._last_twist = twist + + act.update( + { + "action.ee.x": float(pos[0]), + "action.ee.y": float(pos[1]), + "action.ee.z": float(pos[2]), + "action.ee.wx": float(twist[0]), + "action.ee.wy": float(twist[1]), + "action.ee.wz": float(twist[2]), + } + ) + return act + + def reset(self): + self._last_pos = None + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # Because this is last step we specify the dataset features of this step that we want to be stored in the dataset + features["action.ee.x"] = float + features["action.ee.y"] = float + features["action.ee.z"] = float + features["action.ee.wx"] = float + features["action.ee.wy"] = float + features["action.ee.wz"] = float + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") +@dataclass +class 
InverseKinematicsEEToJoints: + """ + Compute the desired joint positions from the desired end-effector pose. + + Input ACTION keys: + { + "action.ee.{x,y,z,wx,wy,wz}" : float + "complementary_data.raw_joint_positions": dict, + } + + Output ACTION keys: + { + "action.joint_name_1.pos": float, + "action.joint_name_2.pos": float, + ... + "action.joint_name_n.pos": float, + } + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + x = act.get("action.ee.x", None) + y = act.get("action.ee.y", None) + z = act.get("action.ee.z", None) + wx = act.get("action.ee.wx", None) + wy = act.get("action.ee.wy", None) + wz = act.get("action.ee.wz", None) + + if None in (x, y, z, wx, wy, wz): + # Nothing to do; restore what we popped and return + act.update( + { + "action.ee.x": x, + "action.ee.y": y, + "action.ee.z": z, + "action.ee.wx": wx, + "action.ee.wy": wy, + "action.ee.wz": wz, + } + ) + transition[TransitionKey.ACTION] = act + return transition + + # Get joint positions from complementary data + raw = comp.get("raw_joint_positions", None) + if raw is None: + raise ValueError( + "raw_joint_positions is not in complementary data and is required for InverseKinematicsEEToJoints" + ) + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + else: # Use previous IK solution as initial guess + if self.q_curr is None: + self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float) + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + 
+ # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + new_act = dict(act) + for i, name in enumerate(self.motor_names): + if name == "gripper": + new_act["observation.state.gripper.pos"] = float(raw["gripper"]) + else: + new_act[f"action.{name}.pos"] = float(q_target[i]) + transition[TransitionKey.ACTION] = new_act + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + features["action.ee.x"] = float + features["action.ee.y"] = float + features["action.ee.z"] = float + features["action.ee.wx"] = float + features["action.ee.wy"] = float + features["action.ee.wz"] = float + + features["observation.state.gripper.pos"] = float + features["action.gripper.pos"] = float + return features + + def reset(self): + self.q_curr = None + + +@ProcessorStepRegistry.register("gripper_velocity_to_joint") +@dataclass +class GripperVelocityToJoint: + """ + Convert the gripper velocity to a joint velocity. 
+ + Input ACTION keys: + { + "action.gripper": float, + } + + Output ACTION keys: + { + "action.gripper.pos": float, + } + """ + + motor_names: list[str] + speed_factor: float = 20.0 + clip_min: float = 0.0 + clip_max: float = 100.0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION) or {} + act = transition.get(TransitionKey.ACTION) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + + if "action.gripper" not in act: + return transition + + if "gripper" not in self.motor_names: + new_act = dict(act) + new_act.pop("action.gripper", None) + transition[TransitionKey.ACTION] = new_act + return transition + + # Get current gripper position from complementary data + raw = comp.get("raw_joint_positions") or {} + curr_pos = float(raw.get("gripper")) + + # Compute desired gripper velocity + u = float(act.get("action.gripper", 0.0)) + delta = u * float(self.speed_factor) + gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max)) + + new_act = dict(act) + new_act["action.gripper.pos"] = gripper_pos + new_act.pop("action.gripper", None) + transition[TransitionKey.ACTION] = new_act + + obs.update({"observation.state.gripper.pos": curr_pos}) + transition[TransitionKey.OBSERVATION] = obs + return transition + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + features["observation.state.gripper.pos"] = float + features["action.gripper.pos"] = float + return features + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee") +@dataclass +class ForwardKinematicsJointsToEE(ObservationProcessor): + """ + Compute the end-effector pose from the joint positions. 
+ + Input OBSERVATION keys: + { + "observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float, + } + + Output OBSERVATION keys: + { + "observation.state.ee.{x,y,z,wx,wy,wz}" : float + } + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def observation(self, obs: dict | None) -> dict: + if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names): + return obs + + q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float) + t = self.kinematics.forward_kinematics(q) + pos = t[:3, 3] + tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() + + obs.update( + { + "observation.state.ee.x": float(pos[0]), + "observation.state.ee.y": float(pos[1]), + "observation.state.ee.z": float(pos[2]), + "observation.state.ee.wx": float(tw[0]), + "observation.state.ee.wy": float(tw[1]), + "observation.state.ee.wz": float(tw[2]), + } + ) + return obs + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz"]: + features[f"observation.state.ee.{k}"] = float + return features + + +@ProcessorStepRegistry.register("add_robot_observation") +@dataclass +class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor): + """ + Read the robot's current observation and insert it into the transition as complementary data. + + - Joint positions are added under complementary_data["raw_joint_positions"] as a dict: + { "": , ... 
} + """ + + robot: Robot + + def complementary_data(self, comp: dict | None) -> dict: + comp = {} if comp is None else dict(comp) + obs = self.robot.get_observation() + + comp["raw_joint_positions"] = { + k.removesuffix(".pos"): float(v) + for k, v in obs.items() + if isinstance(k, str) and k.endswith(".pos") + } + return comp + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py b/src/lerobot/robots/so100_follower/so100_follower_end_effector.py deleted file mode 100644 index 5fe2993cb..000000000 --- a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py +++ /dev/null @@ -1,200 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from typing import Any - -import numpy as np - -from lerobot.cameras import make_cameras_from_configs -from lerobot.errors import DeviceNotConnectedError -from lerobot.model.kinematics import RobotKinematics -from lerobot.motors import Motor, MotorNormMode -from lerobot.motors.feetech import FeetechMotorsBus - -from . import SO100Follower -from .config_so100_follower import SO100FollowerEndEffectorConfig - -logger = logging.getLogger(__name__) - - -class SO100FollowerEndEffector(SO100Follower): - """ - SO100Follower robot with end-effector space control. 
- - This robot inherits from SO100Follower but transforms actions from - end-effector space to joint space before sending them to the motors. - """ - - config_class = SO100FollowerEndEffectorConfig - name = "so100_follower_end_effector" - - def __init__(self, config: SO100FollowerEndEffectorConfig): - super().__init__(config) - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - - self.cameras = make_cameras_from_configs(config.cameras) - - self.config = config - - # Initialize the kinematics module for the so100 robot - if self.config.urdf_path is None: - raise ValueError( - "urdf_path must be provided in the configuration for end-effector control. " - "Please set urdf_path in your SO100FollowerEndEffectorConfig." - ) - - self.kinematics = RobotKinematics( - urdf_path=self.config.urdf_path, - target_frame_name=self.config.target_frame_name, - ) - - # Store the bounds for end-effector position - self.end_effector_bounds = self.config.end_effector_bounds - - self.current_ee_pos = None - self.current_joint_pos = None - - @property - def action_features(self) -> dict[str, Any]: - """ - Define action features for end-effector control. - Returns dictionary with dtype, shape, and names. - """ - return { - "dtype": "float32", - "shape": (4,), - "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, - } - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """ - Transform action from end-effector space to joint space and send to motors. 
- - Args: - action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control - or a numpy array with [delta_x, delta_y, delta_z] - - Returns: - The joint-space action that was sent to the motors - """ - - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Convert action to numpy array if not already - if isinstance(action, dict): - if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): - delta_ee = np.array( - [ - action["delta_x"] * self.config.end_effector_step_sizes["x"], - action["delta_y"] * self.config.end_effector_step_sizes["y"], - action["delta_z"] * self.config.end_effector_step_sizes["z"], - ], - dtype=np.float32, - ) - if "gripper" not in action: - action["gripper"] = [1.0] - action = np.append(delta_ee, action["gripper"]) - else: - logger.warning( - f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" - ) - action = np.zeros(4, dtype=np.float32) - - if self.current_joint_pos is None: - # Read current joint positions - current_joint_pos = self.bus.sync_read("Present_Position") - self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) - - # Calculate current end-effector position using forward kinematics - if self.current_ee_pos is None: - self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) - - # Set desired end-effector position by adding delta - desired_ee_pos = np.eye(4) - desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation - - # Add delta to position and clip to bounds - desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] - if self.end_effector_bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], - self.end_effector_bounds["min"], - self.end_effector_bounds["max"], - ) - - # Compute inverse kinematics to get joint positions - target_joint_values_in_degrees = self.kinematics.inverse_kinematics( - self.current_joint_pos, 
desired_ee_pos - ) - - # Create joint space action dictionary - joint_action = { - f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) - } - - # Handle gripper separately if included in action - # Gripper delta action is in the range 0 - 2, - # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos - joint_action["gripper.pos"] = np.clip( - self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, - 5, - self.config.max_gripper_pos, - ) - - self.current_ee_pos = desired_ee_pos.copy() - self.current_joint_pos = target_joint_values_in_degrees.copy() - self.current_joint_pos[-1] = joint_action["gripper.pos"] - - # Send joint space action to parent class - return super().send_action(joint_action) - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Read arm position - start = time.perf_counter() - obs_dict = self.bus.sync_read("Present_Position") - obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read state: {dt_ms:.1f}ms") - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - - return obs_dict - - def reset(self): - self.current_ee_pos = None - self.current_joint_pos = None diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 7486ee499..87e751b26 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -69,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot: raise ValueError(config.type) +# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use 
in dataset def ensure_safe_goal_position( goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float] ) -> dict[str, float]: diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 3c72caf79..320140bdb 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -109,7 +109,7 @@ def teleop_loop( action = teleop.get_action() if display_data: observation = robot.get_observation() - log_rerun_data(observation, action) + log_rerun_data(observation=observation, action=action) robot.send_action(action) dt_s = time.perf_counter() - loop_start diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py new file mode 100644 index 000000000..f82ab11e1 --- /dev/null +++ b/src/lerobot/teleoperators/phone/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_phone import PhoneConfig +from .phone import Phone diff --git a/src/lerobot/teleoperators/phone/config_phone.py b/src/lerobot/teleoperators/phone/config_phone.py new file mode 100644 index 000000000..380d5f5ff --- /dev/null +++ b/src/lerobot/teleoperators/phone/config_phone.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +from ..config import TeleoperatorConfig + + +class PhoneOS(Enum): + ANDROID = "android" + IOS = "ios" + + +@TeleoperatorConfig.register_subclass("phone") +@dataclass +class PhoneConfig(TeleoperatorConfig): + phone_os: PhoneOS = PhoneOS.IOS + camera_offset = np.array( + [0.0, -0.02, 0.04] + ) # iPhone 14 Pro camera is 2cm off center and 4cm above center diff --git a/src/lerobot/teleoperators/phone/phone.py b/src/lerobot/teleoperators/phone/phone.py new file mode 100644 index 000000000..3c6d5fc5d --- /dev/null +++ b/src/lerobot/teleoperators/phone/phone.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Docs: +# hebi: https://docs.hebi.us/tools.html#mobile-io +# teleop: https://github.com/SpesRobotics/teleop + +import logging +import threading +import time + +import hebi +import numpy as np +from scipy.spatial.transform import Rotation +from teleop import Teleop + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.teleoperator import Teleoperator + +logger = logging.getLogger(__name__) + + +class Phone(Teleoperator): + """ + Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). + For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. + + Press and hold **B1** to enable teleoperation. The first B1 press captures a reference + position and rotation; when B1 is released and pressed again, the position reference is re-captured. + """ + + config_class = PhoneConfig + name = "phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._group = None + self._teleop = None + self._teleop_thread = None + self._latest_pose = None + self._latest_message = None + self._enabled: bool = False + self._calib_pos: np.ndarray | None = None + self._calib_rot_inv: Rotation | None = None + + @property + def is_connected(self) -> bool: + return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or ( + self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None + ) + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + if self.config.phone_os == PhoneOS.IOS: + logger.info("Connecting to iPhone, make sure to open the HEBI Mobile I/O app.") + lookup = hebi.Lookup() + time.sleep(2.0) + group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) + if group is None: + raise RuntimeError("Mobile I/O not found; check name/family
settings in the app.") + self._group = group + logger.info(f"{self} connected to HEBI group with {group.size} module(s).") + elif self.config.phone_os == PhoneOS.ANDROID: + logger.info("Starting teleop stream for Android...") + self._teleop = Teleop() + self._teleop.subscribe(self._android_callback) + self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) + self._teleop_thread.start() + logger.info(f"{self} connected, teleop stream started.") + else: + raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + if self.config.phone_os == PhoneOS.IOS: + print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n") + else: + print("Touch and move on the WebXR page to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _reapply_position_calibration(self, pos: np.ndarray) -> None: + self._calib_pos = pos.copy() + + @property + def is_calibrated(self) -> bool: + return (self._calib_pos is not None) and (self._calib_rot_inv is not None) + + @property + def action_features(self) -> dict[str, type]: + return { + "phone.pos": np.ndarray, # shape (3,) + "phone.rot": Rotation, # scipy.spatial.transform.Rotation + "phone.raw_inputs": dict, # analogs/buttons or webXR meta + "phone.enabled": bool, + } + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """Wait trigger for calibration: iOS: B1. 
Android: 'move'.""" + while True: + ok, pos, rot, pose = self._read_current_pose() + if not ok: + time.sleep(0.01) + continue + + if self.config.phone_os == PhoneOS.IOS: + io = getattr(pose, "io", None) + b = getattr(io, "b", None) if io is not None else None + b1 = False + if b is not None: + b1 = bool(b.get_int(1)) + if b1: + return pos, rot + else: + msg = self._latest_message or {} + if bool(msg.get("move", False)): + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + if self.config.phone_os == PhoneOS.IOS: + fbk = self._group.get_next_feedback() + pose = fbk[0] + ar_pos = getattr(pose, "ar_position", None) + ar_quat = getattr(pose, "ar_orientation", None) + if ar_pos is None or ar_quat is None: + return False, None, None, None + quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw + rot = Rotation.from_quat(quat_xyzw) + pos = ar_pos - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + else: + p = self._latest_pose + if p is None: + return False, None, None, None + rot = Rotation.from_matrix(p[:3, :3]) + pos = p[:3, 3] - rot.apply(self.config.camera_offset) + pose = self._latest_pose + return True, pos, rot, pose + + @property + def feedback_features(self) -> dict[str, type]: + # No haptic or other feedback implemented yet + pass + + def configure(self) -> None: + # No additional configuration required for phone teleop + pass + + def _android_callback(self, pose: np.ndarray, message: dict) -> None: + self._latest_pose = pose + self._latest_message = message + time.sleep(0.001) # 1ms delay to avoid race condition + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + if self.config.phone_os == PhoneOS.IOS: + io = getattr(pose, "io", 
None) + if io is not None: + bank_a, bank_b = io.a, io.b + if bank_a: + for ch in range(1, 9): + if bank_a.has_float(ch): + raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) + if bank_b: + for ch in range(1, 9): + if bank_b.has_int(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) + elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) + else: + msg = self._latest_message or {} + raw_inputs["move"] = bool(msg.get("move", False)) + raw_inputs["scale"] = float(msg.get("scale", 1.0)) + raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) + raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) + + if self.config.phone_os == PhoneOS.IOS: + enable = bool(raw_inputs.get("b1", 0)) + else: + enable = bool(raw_inputs.get("move", False)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def send_feedback(self, feedback: dict[str, float]) -> None: + # We could add haptic feedback (vibrations) here, but it's not implemented yet + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.config.phone_os == PhoneOS.IOS: + self._group = None + else: + self._teleop = None + if self._teleop_thread and self._teleop_thread.is_alive(): + self._teleop_thread.join(timeout=1.0) + self._teleop_thread = None + self._latest_pose = None diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py new file mode 100644 index 000000000..436ee8444 
--- /dev/null +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -0,0 +1,87 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry +from lerobot.teleoperators.phone.config_phone import PhoneOS + + +@ProcessorStepRegistry.register("map_phone_action_to_robot_action") +@dataclass +class MapPhoneActionToRobotAction(ActionProcessor): + """ + Map calibrated phone pose (actions) to the inputs for robot actions + + Expected input ACTION keys: + { + "action.phone.enabled": bool, + "action.phone.pos": np.ndarray, + "action.phone.rot": Rotation, + "action.phone.raw_inputs": dict, + } + + Output ACTION keys: + { + "action.enabled": bool, + "action.ee.{x,y,z,wx,wy,wz}" : float + "action.gripper": float, + } + """ + + platform: PhoneOS + _enabled_prev: bool = field(default=False, init=False, repr=False) + + def action(self, act: dict | None) -> dict: + # Pop them from the action + enabled = act.pop("action.phone.enabled", 0) + pos = act.pop("action.phone.pos", None) + rot = act.pop("action.phone.rot", None) + inputs = act.pop("action.phone.raw_inputs", {}) + + if pos is None or rot is None: + return act + + rotvec = rot.as_rotvec() # Absolute orientation as rotvec + + # Map certain inputs to certain actions + if 
self.platform == PhoneOS.IOS: + gripper = float(inputs.get("a3", 0.0)) + else: + a = float(inputs.get("reservedButtonA", 0.0)) + b = float(inputs.get("reservedButtonB", 0.0)) + gripper = ( + a - b + ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed + + # For some actions we need to invert the axis + act.update( + { + "action.enabled": enabled, + "action.target_x": -pos[1] if enabled else 0.0, + "action.target_y": pos[0] if enabled else 0.0, + "action.target_z": pos[2] if enabled else 0.0, + "action.target_wx": rotvec[1] if enabled else 0.0, + "action.target_wy": rotvec[0] if enabled else 0.0, + "action.target_wz": -rotvec[2] if enabled else 0.0, + "action.gripper": gripper, # Still send gripper action when disabled + } + ) + return act + + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index f0f9aebb7..8a4f65a03 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numbers import os from typing import Any import numpy as np import rerun as rr +from lerobot.processor.pipeline import EnvTransition, TransitionKey + def _init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" @@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None: rr.spawn(memory_limit=memory_limit) -def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]): - for obs, val in observation.items(): - if isinstance(val, float): - rr.log(f"observation.{obs}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - if val.ndim == 1: - for i, v in enumerate(val): - rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v))) +def _is_scalar(x): + return ( + isinstance(x, numbers.Real) + or isinstance(x, (np.integer, np.floating)) + or (isinstance(x, np.ndarray) and x.ndim == 0) + ) + + +def log_rerun_data( + data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None, + *, + observation: dict[str, Any] | None = None, + action: dict[str, Any] | None = None, +) -> None: + items = data if isinstance(data, list) else ([data] if data is not None else []) + + obs = {} if observation is None else dict(observation) + act = {} if action is None else dict(action) + + for idx, item in enumerate(items): + if not isinstance(item, dict): + continue + + if any(isinstance(k, TransitionKey) for k in item.keys()): + o = item.get(TransitionKey.OBSERVATION) or {} + a = item.get(TransitionKey.ACTION) or {} + if isinstance(o, dict): + obs.update(o) + if isinstance(a, dict): + act.update(a) + continue + + keys = list(item.keys()) + has_obs = any(str(k).startswith("observation.") for k in keys) + has_act = any(str(k).startswith("action.") for k in keys) + + if has_obs or has_act: + if has_obs: + obs.update(item) + if has_act: + act.update(item) + else: + # No prefixes: assume first is observation, second is action, others are observation + 
if idx == 0: + obs.update(item) + elif idx == 1: + act.update(item) else: - rr.log(f"observation.{obs}", rr.Image(val), static=True) - for act, val in action.items(): - if isinstance(val, float): - rr.log(f"action.{act}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - for i, v in enumerate(val): - rr.log(f"action.{act}_{i}", rr.Scalar(float(v))) + obs.update(item) + + for k, v in obs.items(): + if v is None: + continue + key = k if str(k).startswith("observation.") else f"observation.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + arr = v + # Convert CHW -> HWC when needed + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + if arr.ndim == 1: + for i, vi in enumerate(arr): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + rr.log(key, rr.Image(arr), static=True) + + for k, v in act.items(): + if v is None: + continue + key = k if str(k).startswith("action.") else f"action.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + if v.ndim == 1: + for i, vi in enumerate(v): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + # Fall back to flattening higher-dimensional arrays + flat = v.flatten() + for i, vi in enumerate(flat): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py new file mode 100644 index 000000000..ae09fb262 --- /dev/null +++ b/tests/datasets/test_dataset_utils.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from huggingface_hub import DatasetCard + +from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features + + +def test_default_parameters(): + card = create_lerobot_dataset_card() + assert isinstance(card, DatasetCard) + assert card.data.tags == ["LeRobot"] + assert card.data.task_categories == ["robotics"] + assert card.data.configs == [ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ] + + +def test_with_tags(): + tags = ["tag1", "tag2"] + card = create_lerobot_dataset_card(tags=tags) + assert card.data.tags == ["LeRobot", "tag1", "tag2"] + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_merge_simple_vectors(): + g1 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.x", "ee.y"], + } + } + g2 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.y", "ee.z"], + } + } + + out = merge_features(g1, g2) + + assert "action" in out + assert out["action"]["dtype"] == "float32" + # Names merged with 
preserved order and de-duplication + assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + # Shape correctly recomputed from names length + assert out["action"]["shape"] == (3,) + + +def test_merge_multiple_groups_order_and_dedup(): + g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + + out = merge_features(g1, g2, g3) + + assert out["action"]["names"] == ["a", "b", "c", "d"] + assert out["action"]["shape"] == (4,) + + +def test_non_vector_last_wins_for_images(): + # Non-vector (images) with same name should be overwritten by the last image specified + g1 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 480, 640), + "names": ["channels", "height", "width"], + } + } + g2 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 720, 1280), + "names": ["channels", "height", "width"], + } + } + + out = merge_features(g1, g2) + assert out["observation.images.front"]["shape"] == (3, 720, 1280) + assert out["observation.images.front"]["dtype"] == "image" + + +def test_dtype_mismatch_raises(): + g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + + with pytest.raises(ValueError, match="dtype mismatch for 'action'"): + _ = merge_features(g1, g2) + + +def test_non_dict_passthrough_last_wins(): + g1 = {"misc": 123} + g2 = {"misc": 456} + + out = merge_features(g1, g2) + # For non-dict entries the last one wins + assert out["misc"] == 456 diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py deleted file mode 100644 index ba16874d0..000000000 --- a/tests/datasets/test_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from datasets import Dataset -from huggingface_hub import DatasetCard - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch - - -def test_default_parameters(): - card = create_lerobot_dataset_card() - assert isinstance(card, DatasetCard) - assert card.data.tags == ["LeRobot"] - assert card.data.task_categories == ["robotics"] - assert card.data.configs == [ - { - "config_name": "default", - "data_files": "data/*/*.parquet", - } - ] - - -def test_with_tags(): - tags = ["tag1", "tag2"] - card = create_lerobot_dataset_card(tags=tags) - assert card.data.tags == ["LeRobot", "tag1", "tag2"] - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py new file mode 100644 index 000000000..590f6a892 --- /dev/null +++ b/tests/processor/test_converters.py @@ -0,0 +1,196 @@ +import numpy as np +import pytest +import torch 
+ +from lerobot.processor.converters import ( + to_dataset_frame, + to_output_robot_action, + to_transition_robot_observation, + to_transition_teleop_action, +) +from lerobot.processor.pipeline import TransitionKey + + +def test_to_transition_teleop_action_prefix_and_tensor_conversion(): + # Scalars, arrays, and "image-like" uint8 arrays are supported + img = np.zeros((8, 12, 3), dtype=np.uint8) + act = { + "ee.x": 0.5, # scalar to torch tensor + "delta": np.array([1.0, 2.0]), # ndarray to torch tensor + "raw_img": img, # uint8 HWC to passthrough ndarray + } + + tr = to_transition_teleop_action(act) + + # Should be an EnvTransition-like dict with ACTION populated + assert isinstance(tr, dict) + assert TransitionKey.ACTION in tr + assert "action.ee.x" in tr[TransitionKey.ACTION] + assert "action.delta" in tr[TransitionKey.ACTION] + assert "action.raw_img" in tr[TransitionKey.ACTION] + + # Types: scalars/arrays -> torch tensor; images to np.ndarray + assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor) + assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5) + + assert isinstance(tr[TransitionKey.ACTION]["action.delta"], torch.Tensor) + assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,) + assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0])) + + assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray) + assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8 + assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3) + + # Observation is created as empty dict by make_transition + assert TransitionKey.OBSERVATION in tr + assert isinstance(tr[TransitionKey.OBSERVATION], dict) + assert tr[TransitionKey.OBSERVATION] == {} + + +def test_to_transition_robot_observation_state_vs_images_split(): + # Create an observation with mixed content + img = np.full((10, 20, 3), 255, dtype=np.uint8) # image (uint8 HWC) + obs = { + "j1.pos": 10.0, # scalar to 
state to torch tensor + "j2.pos": np.float32(20.0), # scalar np to state to torch tensor + "image_front": img, # to images passthrough + "flag": np.int32(7), # scalar to state to torch tensor + "arr": np.array([1.5, 2.5]), # vector to state to torch tensor + } + + tr = to_transition_robot_observation(obs) + assert isinstance(tr, dict) + assert TransitionKey.OBSERVATION in tr + + out = tr[TransitionKey.OBSERVATION] + # Check state keys are present and converted to tensors + for k in ("j1.pos", "j2.pos", "flag", "arr"): + key = f"observation.state.{k}" + assert key in out + v = out[key] + if k != "arr": + assert isinstance(v, torch.Tensor) and v.ndim == 0 + else: + assert isinstance(v, torch.Tensor) and v.ndim == 1 and v.shape == (2,) + + # Check image present as is + assert "observation.images.image_front" in out + assert isinstance(out["observation.images.image_front"], np.ndarray) + assert out["observation.images.image_front"].dtype == np.uint8 + assert out["observation.images.image_front"].shape == (10, 20, 3) + + # ACTION should be empty dict by make_transition + assert TransitionKey.ACTION in tr + assert isinstance(tr[TransitionKey.ACTION], dict) + assert tr[TransitionKey.ACTION] == {} + + +def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only(): + # Build a transition with mixed action keys + tr = { + TransitionKey.ACTION: { + "action.j1.pos": 11.0, # keep "j1.pos" + "action.gripper.pos": torch.tensor(33.0), # keep: tensor accepted + "action.ee.x": 0.5, # ignore (doesn't end with .pos) + "misc": "ignore_me", # ignore (no 'action.' prefix) + } + } + + out = to_output_robot_action(tr) + # Only ".pos" keys with "action." 
prefix are retained and stripped to base names + assert set(out.keys()) == {"j1.pos", "gripper.pos"} + # Values converted to float + assert isinstance(out["j1.pos"], float) + assert isinstance(out["gripper.pos"], float) + assert out["j1.pos"] == pytest.approx(11.0) + assert out["gripper.pos"] == pytest.approx(33.0) + + +def test_to_dataset_frame_merge_and_pack_vectors_and_metadata(): + # Fabricate dataset features (as stored in dataset.meta["features"]) + features = { + # Action vector: 3 elements in specific order + "action": { + "dtype": "float32", + "shape": (3,), + "names": ["j1.pos", "j2.pos", "gripper.pos"], + }, + # Observation state vector: 2 elements + "observation.state": { + "dtype": "float32", + "shape": (2,), + "names": ["j1.pos", "j2.pos"], + }, + # Image spec (video/image dtype acceptable) + "observation.images.front": { + "dtype": "image", + "shape": (480, 640, 3), + "names": ["h", "w", "c"], + }, + } + + # Build two transitions to be merged: teleop (action) and robot obs (state/images) + img = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8) + + teleop_transition = { + TransitionKey.OBSERVATION: {}, + TransitionKey.ACTION: { + "action.j1.pos": torch.tensor(1.1), + "action.j2.pos": torch.tensor(2.2), + # gripper.pos missing → defaults to 0.0 + "action.ee.x": 0.5, # ignored, not in features["action"]["names"] + }, + TransitionKey.COMPLEMENTARY_DATA: { + "frame_is_pad": True, + "task": "Pick cube", + }, + } + + robot_transition = { + TransitionKey.OBSERVATION: { + "observation.state.j1.pos": torch.tensor(10.0), + "observation.state.j2.pos": torch.tensor(20.0), + "observation.images.front": img, + }, + TransitionKey.REWARD: torch.tensor(5.0), + TransitionKey.DONE: True, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {"note": "ok"}, + } + + # Directly call the refactored function + batch = to_dataset_frame([teleop_transition, robot_transition], features) + + # Images passthrough + assert "observation.images.front" in batch + assert 
batch["observation.images.front"].shape == img.shape + assert batch["observation.images.front"].dtype == np.uint8 + assert np.shares_memory(batch["observation.images.front"], img) or np.array_equal( + batch["observation.images.front"], img + ) + + # Observation.state vector + assert "observation.state" in batch + obs_vec = batch["observation.state"] + assert isinstance(obs_vec, np.ndarray) and obs_vec.dtype == np.float32 + assert obs_vec.shape == (2,) + assert obs_vec[0] == pytest.approx(10.0) + assert obs_vec[1] == pytest.approx(20.0) + + # Action vector + assert "action" in batch + act_vec = batch["action"] + assert isinstance(act_vec, np.ndarray) and act_vec.dtype == np.float32 + assert act_vec.shape == (3,) + assert act_vec[0] == pytest.approx(1.1) + assert act_vec[1] == pytest.approx(2.2) + assert act_vec[2] == pytest.approx(0.0) # default for missing gripper.pos + + # Next.* metadata + assert batch["next.reward"] == pytest.approx(5.0) + assert batch["next.done"] is True + assert batch["next.truncated"] is False + + # Complementary data + assert batch["frame_is_pad"] is True + assert batch["task"] == "Pick cube" diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index bda120015..7e30750f4 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -288,8 +288,8 @@ def test_serialization_methods(): assert processor.device == device -def test_feature_contract(): - """Test that feature_contract returns features unchanged.""" +def test_features(): + """Test that features returns features unchanged.""" processor = DeviceProcessor(device="cpu") features = { @@ -297,7 +297,7 @@ def test_feature_contract(): "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), } - result = processor.feature_contract(features) + result = processor.transform_features(features) assert result == features assert result is features # Should return the same object diff --git 
a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 6fc60b49b..97c737e0c 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -621,10 +621,19 @@ def test_serialization_roundtrip(full_stats): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # Verify features and norm_map are correctly reconstructed - assert new_processor.features.keys() == original_processor.features.keys() - for key in new_processor.features: - assert new_processor.features[key].type == original_processor.features[key].type - assert new_processor.features[key].shape == original_processor.features[key].shape + assert ( + new_processor.transform_features(features).keys() + == original_processor.transform_features(features).keys() + ) + for key in new_processor.transform_features(features): + assert ( + new_processor.transform_features(features)[key].type + == original_processor.transform_features(features)[key].type + ) + assert ( + new_processor.transform_features(features)[key].shape + == original_processor.transform_features(features)[key].shape + ) assert new_processor.norm_map == original_processor.norm_map diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index e48b6bc08..4e6efdb6c 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -410,13 +410,13 @@ def test_equivalent_with_image_dict(): torch.testing.assert_close(original_result[key], processor_result[key]) -def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): +def test_image_processor_features_pixels_to_image(policy_feature_factory): processor = VanillaObservationProcessor() features = { "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = 
processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] assert "pixels" not in out @@ -424,13 +424,13 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory assert_contract_is_typed(out) -def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): +def test_image_processor_features_observation_pixels_to_image(policy_feature_factory): processor = VanillaObservationProcessor() features = { "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] assert "observation.pixels" not in out @@ -438,7 +438,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea assert_contract_is_typed(out) -def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): +def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory): processor = VanillaObservationProcessor() features = { "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), @@ -446,7 +446,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (7,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] @@ -456,14 +456,14 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu 
assert_contract_is_typed(out) -def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): +def test_state_processor_features_environment_and_agent_pos(policy_feature_factory): processor = VanillaObservationProcessor() features = { "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] @@ -472,13 +472,13 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu assert_contract_is_typed(out) -def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): +def test_state_processor_features_prefixed_inputs(policy_feature_factory): proc = VanillaObservationProcessor() features = { "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), } - out = proc.feature_contract(features.copy()) + out = proc.transform_features(features.copy()) assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 26e865fad..42a8eb538 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor from lerobot.processor.pipeline import 
TransitionKey from tests.conftest import assert_contract_is_typed @@ -90,8 +91,8 @@ class MockStep: def reset(self) -> None: self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -112,8 +113,8 @@ class MockStepWithoutOptionalMethods: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -168,8 +169,8 @@ class MockStepWithTensorState: self.running_mean.zero_() self.running_count.zero_() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -662,8 +663,8 @@ class MockModuleStep(nn.Module): self.running_mean.zero_() self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -744,8 +745,8 @@ class MockNonModuleStepWithState: self.step_count.zero_() self.history.clear() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -799,8 +800,8 @@ class MockStepWithNonSerializableParam: def reset(self) -> 
None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -838,8 +839,8 @@ class RegisteredMockStep: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features @@ -1382,8 +1383,8 @@ def test_state_file_naming_with_registry(): def load_state_dict(self, state): self.state_tensor = state["state_tensor"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1439,8 +1440,8 @@ def test_override_with_nested_config(): def get_config(self): return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1531,8 +1532,8 @@ def test_override_with_callables(): def get_config(self): return {"name": self.name} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1766,8 +1767,8 @@ def test_override_with_device_strings(): 
def load_state_dict(self, state): self.buffer = state["buffer"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test features here return features try: @@ -1860,21 +1861,16 @@ def test_save_load_with_custom_converter_functions(): class NonCompliantStep: - """Intentionally non-compliant: missing feature_contract.""" + """Intentionally non-compliant: missing features.""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition -def test_construction_rejects_step_without_feature_contract(): - with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): - RobotProcessor([NonCompliantStep()]) - - class NonCallableStep: """Intentionally non-compliant: missing __call__.""" - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -1893,7 +1889,7 @@ class FeatureContractAddStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features[self.key] = self.value return features @@ -1908,7 +1904,7 @@ class FeatureContractMutateStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features[self.key] = self.fn(features.get(self.key)) return features @@ -1920,7 +1916,7 @@ class FeatureContractBadReturnStep: def __call__(self, 
transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return ["not-a-dict"] @@ -1933,12 +1929,12 @@ class FeatureContractRemoveStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: features.pop(self.key, None) return features -def test_feature_contract_orders_and_merges(policy_feature_factory): +def test_features_orders_and_merges(policy_feature_factory): p = RobotProcessor( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), @@ -1946,14 +1942,14 @@ def test_feature_contract_orders_and_merges(policy_feature_factory): FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), ] ) - out = p.feature_contract({}) + out = p.transform_features({}) assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) assert_contract_is_typed(out) -def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): +def test_features_respects_initial_without_mutation(policy_feature_factory): initial = { "seed": policy_feature_factory(FeatureType.STATE, (7,)), "nested": policy_feature_factory(FeatureType.ENV, (0,)), @@ -1966,7 +1962,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto ), ] ) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) assert out["seed"].shape == (8,) assert out["nested"].shape == (5,) @@ -1977,13 +1973,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto assert_contract_is_typed(out) -def 
test_feature_contract_type_error_on_bad_step(): - p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) - with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): - _ = p.feature_contract({}) - - -def test_feature_contract_execution_order_tracking(): +def test_features_execution_order_tracking(): class Track: def __init__(self, label): self.label = label @@ -1991,32 +1981,186 @@ def test_feature_contract_execution_order_tracking(): def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: code = {"A": 1, "B": 2, "C": 3}[self.label] pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) return features - out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) + out = RobotProcessor([Track("A"), Track("B"), Track("C")]).transform_features({}) assert out["order"].shape == (1, 2, 3) -def test_feature_contract_remove_key(policy_feature_factory): +def test_features_remove_key(policy_feature_factory): p = RobotProcessor( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractRemoveStep("a"), ] ) - out = p.feature_contract({}) + out = p.transform_features({}) assert "a" not in out -def test_feature_contract_remove_from_initial(policy_feature_factory): +def test_features_remove_from_initial(policy_feature_factory): initial = { "keep": policy_feature_factory(FeatureType.STATE, (1,)), "drop": policy_feature_factory(FeatureType.STATE, (1,)), } p = RobotProcessor([FeatureContractRemoveStep("drop")]) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) assert "drop" not in out and out["keep"] == 
initial["keep"] + + +@dataclass +class AddActionEEAndJointFeatures: + """Adds both EE and JOINT action features.""" + + def __call__(self, tr): + return tr + + def transform_features(self, features: dict) -> dict: + # EE features + features["action.ee.x"] = float + features["action.ee.y"] = float + # JOINT features + features["action.j1.pos"] = float + features["action.j2.pos"] = float + return features + + +@dataclass +class AddObservationStateFeatures: + """Adds state features (and optionally an image spec to test precedence).""" + + add_front_image: bool = False + front_image_shape: tuple = (240, 320, 3) + + def __call__(self, tr): + return tr + + def transform_features(self, features: dict) -> dict: + # State features (mix EE and a joint state) + features["observation.state.ee.x"] = float + features["observation.state.j1.pos"] = float + if self.add_front_image: + features["observation.images.front"] = self.front_image_shape + return features + + +def test_aggregate_joint_action_only(): + rp = RobotProcessor([AddActionEEAndJointFeatures()]) + initial = {"front": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.j1.pos", "action.j2.pos"], + ) + + # Expect only "action" with joint names + assert "action" in out and "observation.state" not in out + assert out["action"]["dtype"] == "float32" + assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} + assert out["action"]["shape"] == (len(out["action"]["names"]),) + + +def test_aggregate_ee_action_and_observation_with_videos(): + rp = RobotProcessor([AddActionEEAndJointFeatures(), AddObservationStateFeatures()]) + initial = {"front": (480, 640, 3), "side": (720, 1280, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.ee", "observation.state"], + ) + + # Action should pack only EE names + assert "action" in out + assert 
set(out["action"]["names"]) == {"ee.x", "ee.y"} + assert out["action"]["dtype"] == "float32" + + # Observation state should pack both ee.x and j1.pos as a vector + assert "observation.state" in out + assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} + assert out["observation.state"]["dtype"] == "float32" + + # Cameras from initial_features appear as videos + for cam in ("front", "side"): + key = f"observation.images.{cam}" + assert key in out + assert out[key]["dtype"] == "video" + assert out[key]["shape"] == initial[cam] + assert out[key]["names"] == ["height", "width", "channels"] + + +def test_aggregate_both_action_types(): + rp = RobotProcessor([AddActionEEAndJointFeatures()]) + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={}, + use_videos=True, + patterns=["action.ee", "action.j1", "action.j2.pos"], + ) + + assert "action" in out + expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} + assert set(out["action"]["names"]) == expected + assert out["action"]["shape"] == (len(expected),) + + +def test_aggregate_images_when_use_videos_false(): + rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=False, # expect "image" dtype + patterns=None, + ) + + key = "observation.images.back" + key_front = "observation.images.front" + assert key not in out + assert key_front not in out + + +def test_aggregate_images_when_use_videos_true(): + rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=None, + ) + + key = "observation.images.front" + key_back = "observation.images.back" + assert key in out + assert key_back in out + assert out[key]["dtype"] == "video" + assert out[key_back]["dtype"] == "video" + 
assert out[key_back]["shape"] == initial["back"]
+
+
+def test_initial_camera_not_overridden_by_step_image():
+    # Step explicitly sets a different front image shape; initial has another shape.
+    # aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams).
+    rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))])
+    initial = {"front": (480, 640, 3)}  # should NOT override the step-provided (240, 320, 3)
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=True,
+        patterns=["observation.images.front"],
+    )
+
+    key = "observation.images.front"
+    assert key in out
+    assert out[key]["shape"] == (240, 320, 3)  # from the step, not from initial
diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py
index 229d57f9f..398b3ec9c 100644
--- a/tests/processor/test_rename_processor.py
+++ b/tests/processor/test_rename_processor.py
@@ -410,7 +410,7 @@ def test_value_types_preserved():
     assert processed_obs["old_list"] == [1, 2, 3]
 
 
-def test_feature_contract_basic_renaming(policy_feature_factory):
+def test_features_basic_renaming(policy_feature_factory):
     processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
     features = {
         "a": policy_feature_factory(FeatureType.STATE, (2,)),
@@ -418,7 +418,7 @@ def test_feature_contract_basic_renaming(policy_feature_factory):
         "c": policy_feature_factory(FeatureType.ENV, (1,)),
     }
 
-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())
 
     # Values preserved and typed
     assert out["x"] == features["a"]
@@ -430,14 +430,14 @@ def test_feature_contract_basic_renaming(policy_feature_factory):
     assert set(features) == {"a", "b", "c"}
 
 
-def test_feature_contract_overlapping_keys(policy_feature_factory):
+def test_features_overlapping_keys(policy_feature_factory):
     # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
     processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
     features = {
         "a": policy_feature_factory(FeatureType.STATE, (1,)),
         "b": policy_feature_factory(FeatureType.STATE, (2,)),
     }
-    out = processor.feature_contract(features)
+    out = processor.transform_features(features)
 
     assert set(out) == {"b", "c"}
     assert out["b"] == features["a"]  # 'a' renamed to'b'
@@ -445,7 +445,7 @@ def test_feature_contract_overlapping_keys(policy_feature_factory):
     assert_contract_is_typed(out)
 
 
-def test_feature_contract_chained_processors(policy_feature_factory):
+def test_features_chained_processors(policy_feature_factory):
     # Chain two rename processors at the contract level
     processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
     processor2 = RenameProcessor(
@@ -458,7 +458,7 @@ def test_feature_contract_chained_processors(policy_feature_factory):
         "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
         "extra": policy_feature_factory(FeatureType.ENV, (1,)),
     }
-    out = pipeline.feature_contract(initial_features=spec)
+    out = pipeline.transform_features(initial_features=spec)
 
     assert set(out) == {"observation.state", "observation.image", "extra"}
     assert out["observation.state"] == spec["pos"]
diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py
index 452c36da9..784b1ce81 100644
--- a/tests/processor/test_tokenizer_processor.py
+++ b/tests/processor/test_tokenizer_processor.py
@@ -470,7 +470,7 @@ def test_registry_functionality():
 
 
 @require_package("transformers")
-def test_feature_contract_basic():
+def test_features_basic():
     """Test basic feature contract functionality."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
@@ -480,7 +480,7 @@ def test_feature_contract_basic():
         "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
     }
 
-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)
 
     # Check that original features are preserved
     assert "observation.state" in output_features
@@ -501,13 +501,13 @@ def test_feature_contract_basic():
 
 
 @require_package("transformers")
-def test_feature_contract_with_custom_max_length():
+def test_features_with_custom_max_length():
     """Test feature contract with custom max_length."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)
     input_features = {}
 
-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)
 
     # Check that features use correct max_length
     assert f"{OBS_LANGUAGE}.tokens" in output_features
@@ -521,7 +521,7 @@ def test_feature_contract_with_custom_max_length():
 
 
 @require_package("transformers")
-def test_feature_contract_existing_features():
+def test_features_existing_features():
     """Test feature contract when tokenized features already exist."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
@@ -531,7 +531,7 @@ def test_feature_contract_existing_features():
         f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
     }
 
-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)
 
     # Should not overwrite existing features
     assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,)  # Original shape preserved
diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py
new file mode 100644
index 000000000..5e1eb4bab
--- /dev/null
+++ b/tests/utils/test_visualization_utils.py
@@ -0,0 +1,205 @@
+import importlib
+import sys
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+
+from lerobot.processor.pipeline import TransitionKey
+
+
+@pytest.fixture
+def mock_rerun(monkeypatch):
+    """
+    Provide a mock `rerun` module so tests don't depend on the real library.
+    Also reload the module-under-test so it binds to this mock `rr`.
+    """
+    calls = []
+
+    class DummyScalar:
+        def __init__(self, value):
+            self.value = float(value)
+
+    class DummyImage:
+        def __init__(self, arr):
+            self.arr = arr
+
+    def dummy_log(key, obj, **kwargs):
+        calls.append((key, obj, kwargs))
+
+    dummy_rr = SimpleNamespace(
+        Scalar=DummyScalar,
+        Image=DummyImage,
+        log=dummy_log,
+        init=lambda *a, **k: None,
+        spawn=lambda *a, **k: None,
+    )
+
+    # Inject fake module into sys.modules
+    monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
+
+    # Now import and reload the module under test, to bind to our rerun mock
+    import lerobot.utils.visualization_utils as vu
+
+    importlib.reload(vu)
+
+    # Expose both the reloaded module and the call recorder
+    yield vu, calls
+
+
+def _keys(calls):
+    """Helper to extract just the keys logged to rr.log"""
+    return [k for (k, _obj, _kw) in calls]
+
+
+def _obj_for(calls, key):
+    """Find the first object logged under a given key."""
+    for k, obj, _kw in calls:
+        if k == key:
+            return obj
+    raise KeyError(f"Key {key} not found in calls: {calls}")
+
+
+def _kwargs_for(calls, key):
+    for k, _obj, kw in calls:
+        if k == key:
+            return kw
+    raise KeyError(f"Key {key} not found in calls: {calls}")
+
+
+def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
+    vu, calls = mock_rerun
+
+    # Build EnvTransition dict
+    obs = {
+        "observation.state.temperature": np.float32(25.0),
+        # CHW image should be converted to HWC for rr.Image
+        "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
+    }
+    act = {
+        "action.throttle": 0.7,
+        # 1D array should log individual Scalars with suffix _i
+        "action.vector": np.array([1.0, 2.0], dtype=np.float32),
+    }
+    transition = {
+        TransitionKey.OBSERVATION: obs,
+        TransitionKey.ACTION: act,
+    }
+
+    vu.log_rerun_data(transition)
+
+    # We expect:
+    # - observation.state.temperature -> Scalar
+    # - observation.camera -> Image (HWC) with static=True
+    # - action.throttle -> Scalar
+    # - action.vector_0, action.vector_1 -> Scalars
+    expected_keys = {
+        "observation.state.temperature",
+        "observation.camera",
+        "action.throttle",
+        "action.vector_0",
+        "action.vector_1",
+    }
+    assert set(_keys(calls)) == expected_keys
+
+    # Check scalar types and values
+    temp_obj = _obj_for(calls, "observation.state.temperature")
+    assert type(temp_obj).__name__ == "DummyScalar"
+    assert temp_obj.value == pytest.approx(25.0)
+
+    throttle_obj = _obj_for(calls, "action.throttle")
+    assert type(throttle_obj).__name__ == "DummyScalar"
+    assert throttle_obj.value == pytest.approx(0.7)
+
+    v0 = _obj_for(calls, "action.vector_0")
+    v1 = _obj_for(calls, "action.vector_1")
+    assert type(v0).__name__ == "DummyScalar"
+    assert type(v1).__name__ == "DummyScalar"
+    assert v0.value == pytest.approx(1.0)
+    assert v1.value == pytest.approx(2.0)
+
+    # Check image handling: CHW -> HWC
+    img_obj = _obj_for(calls, "observation.camera")
+    assert type(img_obj).__name__ == "DummyImage"
+    assert img_obj.arr.shape == (10, 20, 3)  # transposed
+    assert _kwargs_for(calls, "observation.camera").get("static", False) is True  # static=True for images
+
+
+def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
+    vu, calls = mock_rerun
+
+    # First dict without prefixes treated as observation
+    # Second dict without prefixes treated as action
+    obs_plain = {
+        "temp": 1.5,
+        # Already HWC image => should stay as-is
+        "img": np.zeros((5, 6, 3), dtype=np.uint8),
+        "none": None,  # should be skipped
+    }
+    act_plain = {
+        "throttle": 0.3,
+        "vec": np.array([9, 8, 7], dtype=np.float32),
+    }
+
+    vu.log_rerun_data([obs_plain, act_plain])
+
+    # Expected keys with auto-prefixes
+    expected = {
+        "observation.temp",
+        "observation.img",
+        "action.throttle",
+        "action.vec_0",
+        "action.vec_1",
+        "action.vec_2",
+    }
+    logged = set(_keys(calls))
+    assert logged == expected
+
+    # Scalars
+    t = _obj_for(calls, "observation.temp")
+    assert type(t).__name__ == "DummyScalar"
+    assert t.value == pytest.approx(1.5)
+
+    throttle = _obj_for(calls, "action.throttle")
+    assert type(throttle).__name__ == "DummyScalar"
+    assert throttle.value == pytest.approx(0.3)
+
+    # Image stays HWC
+    img = _obj_for(calls, "observation.img")
+    assert type(img).__name__ == "DummyImage"
+    assert img.arr.shape == (5, 6, 3)
+    assert _kwargs_for(calls, "observation.img").get("static", False) is True
+
+    # Vectors
+    for i, val in enumerate([9, 8, 7]):
+        o = _obj_for(calls, f"action.vec_{i}")
+        assert type(o).__name__ == "DummyScalar"
+        assert o.value == pytest.approx(val)
+
+
+def test_log_rerun_data_kwargs_only(mock_rerun):
+    vu, calls = mock_rerun
+
+    vu.log_rerun_data(
+        None,
+        observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
+        action={"action.a": 1.0},
+    )
+
+    keys = set(_keys(calls))
+    assert "observation.temp" in keys
+    assert "observation.gray" in keys
+    assert "action.a" in keys
+
+    temp = _obj_for(calls, "observation.temp")
+    assert type(temp).__name__ == "DummyScalar"
+    assert temp.value == pytest.approx(10.0)
+
+    img = _obj_for(calls, "observation.gray")
+    assert type(img).__name__ == "DummyImage"
+    assert img.arr.shape == (8, 8, 1)  # remains HWC
+    assert _kwargs_for(calls, "observation.gray").get("static", False) is True
+
+    a = _obj_for(calls, "action.a")
+    assert type(a).__name__ == "DummyScalar"
+    assert a.value == pytest.approx(1.0)