mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
feat(processor): multiple improvements to the pipeline porting (#1749)
* [Port codebase pipeline] General fixes for RL and scripts (#1748) * Refactor dataset configuration in documentation and codebase - Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency. - Adjusted replay episode handling by renaming `episode` to `replay_episode`. - Enhanced documentation - added specific processor to transform from policy actions to delta actions * Added Robot action to tensor processor Added new processor script for dealing with gym specific action processing * removed RobotAction2Tensor processor; imrpoved choosing observations in actor * nit in delta action * added missing reset functions to kinematics * Adapt teleoperate and replay to pipeline similar to record * refactor(processors): move to inheritance (#1750) * fix(teleoperator): improvements phone implementation (#1752) * fix(teleoperator): protect shared state in phone implementation * refactor(teleop): separate classes in phone * fix: solve breaking changes (#1753) * refactor(policies): multiple improvements (#1754) * refactor(processor): simpler logic in device processor (#1755) * refactor(processor): euclidean distance in delta action processor (#1757) * refactor(processor): improvements to joint observations processor migration (#1758) * refactor(processor): improvements to tokenizer migration (#1759) * refactor(processor): improvements to tokenizer migration * fix(tests): tokenizer tests regression from #1750 * fix(processors): fix float comparison and config in hil processors (#1760) * chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761) * refactor(processor): improvements normalize pipeline migration (#1756) * refactor(processor): several improvements normalize processor step * refactor(processor): more improvements normalize processor * refactor(processor): more changes to normalizer * refactor(processor): take a different approach to DRY * refactor(processor): final design * chore(record): revert comment and continue deleted (#1764) * refactor(examples): pipeline phone examples (#1769) * refactor(examples): phone teleop + teleop script * refactor(examples): phone replay + replay * chore(examples): rename phone example files & folders * feat(processor): fix improvements to the pipeline porting (#1796) * refactor(processor): enhance tensor device handling in normalization process (#1795) * refactor(tests): remove unsupported device detection test for complementary data (#1797) * chore(tests): update ToBatchProcessor test (#1798) * refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor * test(tests): add tests for action and task processing in batch processor * add names for android and ios phone (#1799) * use _tensor_stats in normalize processor (#1800) * fix(normalize_processor): correct device reference for tensor epsilon handling (#1801) * add point 5 add missing feature contracts (#1806) * Fix PR comments 1452 (#1807) * use key to determine image * Address rest of PR comments * use PolicyFeatures in transform_features --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -11,20 +11,88 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class ToBatchProcessorAction(ActionProcessor):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
|
||||
def action(self, action):
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class ToBatchProcessorObservation(ObservationProcessor):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
|
||||
def observation(self, observation):
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class ToBatchProcessor:
|
||||
class ToBatchProcessor(ProcessorStep):
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
@@ -59,81 +127,16 @@ class ToBatchProcessor:
|
||||
```
|
||||
"""
|
||||
|
||||
to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction)
|
||||
to_batch_observation_processor: ToBatchProcessorObservation = field(
|
||||
default_factory=ToBatchProcessorObservation
|
||||
)
|
||||
to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field(
|
||||
default_factory=ToBatchProcessorComplementaryData
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
self._process_observation(transition)
|
||||
self._process_action(transition)
|
||||
self._process_complementary_data(transition)
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def _process_observation(self, transition: EnvTransition) -> None:
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return
|
||||
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
def _process_action(self, transition: EnvTransition) -> None:
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, Tensor) and action.dim() == 1:
|
||||
transition[TransitionKey.ACTION] = action.unsqueeze(0)
|
||||
|
||||
def _process_complementary_data(self, transition: EnvTransition) -> None:
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return
|
||||
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user