2025-07-10 10:39:11 +02:00
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import logging.handlers
|
|
|
|
|
import os
|
|
|
|
|
import time
|
2025-10-20 23:34:24 +02:00
|
|
|
from dataclasses import dataclass, field
|
2025-07-10 10:39:11 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from lerobot.configs.types import PolicyFeature
|
|
|
|
|
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
|
|
|
|
|
|
|
|
|
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
Add OpenPi, Pi0 and Pi0.5 (#1910)
* initial commit
* change device in test
* do detailed import
* adhere to python 3.11 syntax
* fix autodocstring
* additionally
* do same in other files
* add model. prefix to all keys in state dict
* use dummy stats
* add pi05
* also shorten action_steps
* fix test
* all test pass! and fix tokenizer max length between 05 and 0
* remove test
* fix transformer dependency
* fix test
* split pi0 and pi05 policy in separate files
* fix test
* fix push to hub test
* add some comments, license and readme
* remove warning in config
* add pi05 to factory
* remove check
* rename action_horizon to chunk_size
* clean up padding of state and action (more in line with lerobot pi0)
* add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0
* fix key match from pytorch state dict (similar keys to openpi implementation now)
* also for pi05
* update to python 3.11
* revert to openpi transformer replace python 3.11
* fix(modeling pi0): nit warning message
* use safeauto_docstring
* fix: remove unused param
* fix from pretrained
* add preprocess tests
* also compile forward method
* Do not add model prefix to normalization
* use same name for action and state dim as lerobot pi0 and remove fixed image keys
* load from pretrained_path
* temp: hardcode base model
* fix override self.pretrained_path = None overwrite
* rename to loss
* remove additional image augmentations, lerobot dataset already does this
* Add docs
* put tests in test folder
* Add test to instantiate all base models
* go back to python 3.10
* update docs
* adapt docs pi05
* change docs: finetune base model options
* minor docs fixes and dependencies
* remove todo
* cast float64 to float32 for mps
* skip if no transformers
* fix tests
* add new models to modelcard
* add back init
* fix circular input
* feat: only run pi test on GPU
* remove require_nightly_gpu
* replace decorator test_pi0_openpi
* rename action_dim, state_dim to max_action_dim, max_state_dim
* fix doc and constants
* cleanup tests
* fix from pretrained
* fix tests
* add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests
* fix, state is included in language not in flow head
* Move test to specific folder
* and paligemma task with newline
* remove add_special_tokens, not needed
* feedback pr
* Remove previous pi0 and rename pi0_openpi and pi05_openpi
* Add Quantile stats to LeRobotDataset (#1985)
* - Add RunningQuantileStats class for efficient histogram-based quantile computation
- Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
- Support quantile computation during episode collection and aggregation
- Add comprehensive function-based test suite (24 tests) for quantile functionality
- Maintain full backward compatibility with existing stats computation
- Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization
* style fixes, make quantiles computation by default to new datasets
* fix tests
* - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user
- Fortified tests.
* - add helper functions to reshape stats
- add missing test for quantiles
* - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles.
- Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles.
* style fixes
* Added missing license
* Simplify compute_stats
- added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that don't have quantiles
- modified quantile computation instead of using the edge for the value, interpolate the values in the bin
* rename pi0/pi05 files
* Remove open pi patch and use custom transformer branch for now
* renaming
* fix
* Revert "fix"
This reverts commit 1ea65730ac2cbca6e5869df734fbd4392561b3c6.
* fix naming
* feet(pi0/pi0.5): add pipeline (#2009)
* feat(processor): convert openpi model with processor
* TODO: Make test works
* fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests
- Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`.
- Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`.
- Enhanced task handling in tests to ensure proper formatting and batch size consistency.
- Cleaned up commented-out test code for clarity.
* refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.
* refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration
- Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions.
- Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`.
- Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter.
- Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability.
- Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility.
* feat(processor): convert openpi model with processor
* TODO: Make test works
* fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests
- Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`.
- Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`.
- Enhanced task handling in tests to ensure proper formatting and batch size consistency.
- Cleaned up commented-out test code for clarity.
* refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy
- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.
* refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration
- Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions.
- Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`.
- Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter.
- Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability.
- Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility.
* refactor(pi05): update imports and rename configuration classes
- Changed imports to reflect the new naming convention for PI05 configuration and policy classes.
- Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency.
- Introduced a new processor file for PI05, implementing pre-processing and post-processing steps.
- Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase.
* update(pi05): increase tokenizer_max_length for improved processing
- Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences.
- This adjustment aims to improve the overall performance and flexibility of the PI05 configuration.
* add default for state (max_state_dim)
* correct naming
* fix import
* cleanup code
* remove unused test
* us quantiles for action
* move to device
* remove discrete state assert
* fix pi05 test
* move pi05 to device
* use base models in comparison tests
* small renames for tests
* change number of tokens pi05 test
* fix openpi tokenization in test
* fix hub test
* fix test
* assert lerobot vs openpi tests
---------
Co-authored-by: Pepijn <pepijn@huggingface.co>
* add headers
* add back previously removed imports
* update if statement load processor with dataset stats
* remove to avoid circular import
* inject dataset stats for pretrained models
* check normalization before applying
* add link to quantile augument script
* fix(policies): transformers import for ci in PI0 & PI05 (#2039)
* fix(policies): transformers import for ci in PI0
* fix(policies): transformers import for ci in PI05
* test(processor): fix expected raise when normalization types are missing (#2040)
* switch normalization order pipeline for pi05
* Fix/quantiles script (#2064)
* refactor augment stats with quantiles script
add parallelization for faster processing
shift the quantile normalization between -1 1
* fix replay buffer tests
* fix comment
* overwrite the pipeline normalization features with the policy features
* remove double normalization overwrite
* cleanup from pretrained
* remove typo
* also set norm_map
* fix(augment_quantiles) images incorrectly divided by 255
* clamp quantiles
* link to lerobot base models
* rename tests
* incorporate PR feedback
* update docstring for RunningQuantileStats
* update doc links
* Revert "clamp quantiles"
This reverts commit 172207471c8f2cb62958e9a9e6a0535ba3ff67d4.
* fix self.paligemma
* fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1]
* fix libero doc and use different transformer branch
* use fix branch instead of feat
* update results libero
* add new line
* fix formatting
* precommit
* update results libero
* update libero doc
* update title
* final changes
* add quantiles to test
* run pre commit
---------
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-02 13:14:45 +02:00
|
|
|
from lerobot.policies import ( # noqa: F401
|
|
|
|
|
ACTConfig,
|
|
|
|
|
DiffusionConfig,
|
|
|
|
|
PI0Config,
|
|
|
|
|
PI05Config,
|
|
|
|
|
SmolVLAConfig,
|
|
|
|
|
VQBeTConfig,
|
|
|
|
|
)
|
2025-07-10 10:39:11 +02:00
|
|
|
from lerobot.robots.robot import Robot
|
2025-09-25 15:36:47 +02:00
|
|
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
2025-07-10 10:39:11 +02:00
|
|
|
from lerobot.utils.utils import init_logging
|
|
|
|
|
|
|
|
|
|
# an action is represented as a single torch tensor
Action = torch.Tensor

# observation as received from the robot
RawObservation = dict[str, torch.Tensor]

# observation keyed as in a LeRobot dataset (keys differ from the raw robot keys)
LeRobotObservation = dict[str, torch.Tensor]

# observation ready for policy inference (image entries resized)
Observation = dict[str, torch.Tensor]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
    """Plot the recorded action queue size against the environment step index.

    Args:
        action_queue_size: Queue sizes, one entry per environment step. May be
            empty, in which case an empty plot is shown instead of raising.
    """
    # Imported inside the function, so matplotlib is only loaded when this is called.
    import matplotlib.pyplot as plt

    # `max()` raises ValueError on an empty sequence; fall back to 0 in that case.
    peak = max(action_queue_size, default=0)

    _, ax = plt.subplots()
    ax.set_title("Action Queue Size Over Time")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Action Queue Size")
    # Leave 10% headroom above the peak. With a peak of 0 the range would be the
    # degenerate (0, 0), so use a unit-height axis instead.
    ax.set_ylim(0, peak * 1.1 if peak > 0 else 1)
    ax.grid(True, alpha=0.3)
    ax.plot(range(len(action_queue_size)), action_queue_size)
    plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
    """Translate the robot's observation features into dataset-style features.

    Delegates to `hw_to_dataset_features`, using `OBS_STR` as the prefix and
    with `use_video=False`.
    """
    robot_features = robot.observation_features
    return hw_to_dataset_features(robot_features, OBS_STR, use_video=False)
|
2025-07-10 10:39:11 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_image_key(k: str) -> bool:
    """Tell whether the key `k` begins with the `OBS_IMAGES` prefix."""
    image_prefix = OBS_IMAGES
    return k.startswith(image_prefix)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_robot_observation_image(image: torch.Tensor, resize_dims: tuple[int, int, int]) -> torch.Tensor:
    """Resize one channel-last image to the spatial size given by `resize_dims`.

    Args:
        image: 3-D tensor laid out as (H, W, C).
        resize_dims: Target shape; only entries 1 and 2 (height, width) are used.
            Presumably given as (C, H, W) — confirm against the caller.

    Returns:
        Channel-first tensor of shape (C, resize_dims[1], resize_dims[2]).

    Raises:
        AssertionError: If `image` is not 3-dimensional.
    """
    assert image.ndim == 3, f"Image must be (H, W, C)! Received {image.shape}"

    # (H, W, C) -> (C, H, W) for resizing from robot observation resolution to policy image resolution
    image = image.permute(2, 0, 1)
    dims = (resize_dims[1], resize_dims[2])

    # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
    image_batched = image.unsqueeze(0)

    # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
    resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)

    return resized.squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
2025-10-07 15:10:31 +02:00
|
|
|
# TODO(Steven): Consider implementing a pipeline step for this
|
2025-07-10 10:39:11 +02:00
|
|
|
def raw_observation_to_observation(
    raw_observation: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
) -> Observation:
    """Turn a raw robot observation into the dict handed to the policy.

    The raw observation is first passed through `prepare_raw_observation`.
    Tensor entries whose key contains "image" are then run through
    `prepare_image` and given a leading batch dimension; every other entry
    (including non-tensor values) is passed through unchanged.
    """
    prepared = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)

    def _finalize(key, value):
        # Non-tensor values are kept as-is (VLAs present natural-language instructions in observations)
        if isinstance(value, torch.Tensor) and "image" in key:
            # Policy expects images in shape (B, C, H, W)
            return prepare_image(value).unsqueeze(0)
        return value

    return {key: _finalize(key, value) for key, value in prepared.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_image(image: torch.Tensor) -> torch.Tensor:
    """Cast an image to float32, divide it by 255, and return a memory-contiguous tensor."""
    scaled = image.to(torch.float32) / 255
    return scaled.contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_state_from_raw_observation(
    lerobot_obs: RawObservation,
) -> torch.Tensor:
    """Return the `OBS_STATE` entry of the observation as a tensor.

    A 1-D state gets a leading dimension added; any other rank is returned
    as converted.
    """
    state = torch.tensor(lerobot_obs[OBS_STATE])
    # (state_dim,) -> (1, state_dim); other ranks are left untouched
    return state.unsqueeze(0) if state.ndim == 1 else state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_images_from_raw_observation(
    lerobot_obs: RawObservation,
    camera_key: str,
) -> torch.Tensor:
    """Extract one camera image from an observation and convert it to a tensor.

    Args:
        lerobot_obs: Observation dict holding the image under `camera_key`.
        camera_key: Key of the image to extract.

    Returns:
        The value stored under `camera_key`, converted with `torch.tensor`.
        This is a single tensor, not a dict.

    Raises:
        KeyError: If `camera_key` is not present in `lerobot_obs`.
    """
    return torch.tensor(lerobot_obs[camera_key])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_lerobot_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
) -> LeRobotObservation:
    """Build a LeRobot-style observation from a raw robot observation.

    Delegates to `build_dataset_frame` with `OBS_STR` as the key prefix.
    """
    frame = build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
    return frame
|
2025-07-10 10:39:11 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_raw_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
) -> Observation:
    """Match keys from the raw `robot_obs` dict to the keys expected by a given policy.

    Args:
        robot_obs: Raw observation from the robot. An optional "task" entry is
            forwarded unchanged.
        lerobot_features: Feature description passed to `make_lerobot_observation`.
        policy_image_features: Policy features indexed by image key; each entry's
            `shape` gives the target size used for resizing.

    Returns:
        Dict holding the state under `OBS_STATE`, the optional "task" entry, and
        one resized image per image key of the observation.

    Raises:
        KeyError: If an image key of the observation is missing from
            `policy_image_features`.
    """
    # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
    # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
    lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)

    # 2. Greps all observation.images.<> keys
    image_keys = list(filter(is_image_key, lerobot_obs))

    # state's shape is expected as (B, state_dim)
    state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}

    # Turns the image features to (C, H, W) with H, W matching the policy image features.
    # This reduces the resolution of the images. Each image is converted to a tensor
    # exactly once (the earlier, immediately-overwritten conversion pass was removed).
    image_dict = {
        key: resize_robot_observation_image(
            extract_images_from_raw_observation(lerobot_obs, key), policy_image_features[key].shape
        )
        for key in image_keys
    }

    if "task" in robot_obs:
        state_dict["task"] = robot_obs["task"]

    return {**state_dict, **image_dict}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
    """
    Return a named logger after running the shared `init_logging` setup.

    Args:
        name: Logger name (e.g., 'policy_server', 'robot_client')
        log_to_file: When true, a `logs` directory is created if needed and a
            log file named `<name>_<unix seconds>.log` inside it is passed to
            `init_logging`

    Returns:
        The logger obtained from `logging.getLogger(name)`
    """
    log_file = None
    if log_to_file:
        # The directory must exist before the log file path is handed over
        os.makedirs("logs", exist_ok=True)
        log_file = Path(f"logs/{name}_{int(time.time())}.log")

    # Standardized logging initialization, without the PID in the output
    init_logging(log_file=log_file, display_pid=False)

    return logging.getLogger(name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TimedData:
    """Base record pairing a timestamp with a timestep.

    Args:
        timestamp: Timestamp attached to the data, as a float.
        timestep: The timestep of the data, as an integer.
    """

    timestamp: float
    timestep: int

    def get_timestamp(self) -> float:
        """Return the stored timestamp."""
        return self.timestamp

    def get_timestep(self) -> int:
        """Return the stored timestep."""
        return self.timestep
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TimedAction(TimedData):
    """A `TimedData` record that additionally carries an action."""

    action: Action

    def get_action(self) -> Action:
        """Return the stored action."""
        return self.action
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TimedObservation(TimedData):
    """A `TimedData` record that additionally carries a raw observation."""

    observation: RawObservation
    # NOTE(review): presumably marks observations that must be processed rather
    # than skipped — confirm against the code that reads this flag.
    must_go: bool = False

    def get_observation(self) -> RawObservation:
        """Return the stored observation."""
        return self.observation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class FPSTracker:
|
|
|
|
|
"""Utility class to track FPS metrics over time."""
|
|
|
|
|
|
|
|
|
|
target_fps: float
|
|
|
|
|
first_timestamp: float = None
|
|
|
|
|
total_obs_count: int = 0
|
|
|
|
|
|
|
|
|
|
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
|
|
|
|
|
"""Calculate average FPS vs target"""
|
|
|
|
|
self.total_obs_count += 1
|
|
|
|
|
|
|
|
|
|
# Initialize first observation time
|
|
|
|
|
if self.first_timestamp is None:
|
|
|
|
|
self.first_timestamp = current_timestamp
|
|
|
|
|
|
|
|
|
|
# Calculate overall average FPS (since start)
|
|
|
|
|
total_duration = current_timestamp - self.first_timestamp
|
|
|
|
|
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
|
|
|
|
|
|
|
|
|
|
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
"""Reset the FPS tracker state"""
|
|
|
|
|
self.first_timestamp = None
|
|
|
|
|
self.total_obs_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class RemotePolicyConfig:
    """Bundle of settings describing a policy and how it should be run.

    NOTE(review): the per-field notes below are inferred from field names and
    annotations only — confirm against the code that consumes this config.
    """

    # presumably the identifier of the policy family — TODO confirm
    policy_type: str
    # presumably a hub id or local path of the pretrained weights — TODO confirm
    pretrained_name_or_path: str
    # annotated as PolicyFeature values here, while functions in this module
    # annotate `lerobot_features` as dict[str, dict] — TODO confirm which is right
    lerobot_features: dict[str, PolicyFeature]
    # presumably the number of actions produced per inference call — TODO confirm
    actions_per_chunk: int
    # defaults to "cpu"
    device: str = "cpu"
    # defaults to a fresh empty dict per instance
    rename_map: dict[str, str] = field(default_factory=dict)
|
2025-07-10 10:39:11 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
    """Return True when the norm of the difference between the two states is strictly below `atol`."""
    difference = obs1_state - obs2_state
    distance = torch.linalg.norm(difference)
    return bool(distance < atol)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def observations_similar(
    obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
    """Decide whether two observations are similar, under a tolerance threshold.

    Only the state is compared: each observation is converted with
    `make_lerobot_observation`, its state is extracted, and the two states are
    handed to `_compare_observation_states` with `atol`.

    NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
    An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
    to surpass this joint-space similarity check.
    """
    first_state, second_state = (
        extract_state_from_raw_observation(make_lerobot_observation(timed_obs.get_observation(), lerobot_features))
        for timed_obs in (obs1, obs2)
    )
    return _compare_observation_states(first_state, second_state, atol=atol)
|