Compare commits

...

19 Commits

Author SHA1 Message Date
Khalil Meftah
ef8bfffbd7 fix(rl): enhance intervention handling in actor and learner 2026-04-26 23:09:33 +02:00
Khalil Meftah
f887ab3f6a fix(rl): improve action processing for discrete and continuous actions 2026-04-26 22:47:52 +02:00
Khalil Meftah
c2556439e5 fix(rl): postprocess action in actor 2026-04-26 18:15:04 +02:00
Khalil Meftah
d2a046dfc5 fix(rl): mirror gym_manipulator in actor 2026-04-26 18:11:26 +02:00
Khalil Meftah
613d581f6c remove debug 2026-04-26 18:08:13 +02:00
Khalil Meftah
58b6d844c4 debug 2026-04-26 17:33:15 +02:00
Khalil Meftah
30e1886b64 fix(rl): merge environment and action-processor info in transition processing 2026-04-26 17:12:37 +02:00
Khalil Meftah
9c9064e5be fix(rl): update neutral gripper action 2026-04-26 16:42:53 +02:00
Khalil Meftah
494f469a2b fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100 2026-04-26 16:41:55 +02:00
Khalil Meftah
cd105f65cb fix(rl): add time limit processor to environment pipeline 2026-04-26 16:38:20 +02:00
Khalil Meftah
9c2af818ff fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline 2026-04-26 16:36:21 +02:00
Khalil Meftah
6495bb9706 add processor to main 2026-04-24 17:06:57 +02:00
Steven Palma
580d818aa9 fix(dataset): no default overwrite in lerobot tool recompute stats (#3452) 2026-04-24 15:07:19 +02:00
Steven Palma
587aa82021 fix(imports): realsense import name is platform dependent (#3451) 2026-04-24 12:55:38 +02:00
Chuyao Shen
12b88fce02 not use dataclass (#3414)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-24 11:26:59 +02:00
masato-ka
fc6c94c82a fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in… (#3419)
* fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in CLIP encoding

In transformers 5.x, CLIPModel.get_image_features() and get_text_features()
return BaseModelOutputWithPooling instead of a plain torch.FloatTensor.
Added isinstance check to extract pooler_output when the return value is not
a tensor, maintaining backward compatibility with transformers 4.x.

Fixes AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach'

* Adding assertion check for pooler_output of CLIP. This change is response to below comment.
https://github.com/huggingface/lerobot/pull/3419#discussion_r3112594387

* Adding assertion check for pooler_output of CLIP. This change is response to below comment. Change to simple check and rise
https://github.com/huggingface/lerobot/pull/3419#discussion_r3126953776

---------
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-23 16:26:58 +02:00
Steven Palma
1add460678 fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442)
* Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT

When action_is_pad masks out padded timesteps, the subsequent .mean()
still divides by the total element count (including zeroed-out padding),
underestimating the loss. With 60-70% padding this can cut the effective
gradient signal by 2-3x.

Replace mask-then-mean with mask-then-sum / valid-count for all three
affected policies. TDMPC is not affected because it sums over time
before averaging over batch.

Fixes #3353

* linting

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/diffusion/modeling_diffusion.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* apply ACT loss normalization suggestion from review

Divide by num_valid (timesteps * action_dim) instead of just timesteps,
matching the diffusion/multi_task_dit fix. Addresses review from
@whats2000 (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791).

* fix(test): update safetensor act

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
2026-04-23 15:23:54 +02:00
Qi Jia
4587c2b648 fix xvla docs (#3291)
Co-authored-by: Qi Jia <kaufou@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-04-23 14:50:32 +02:00
whats2000
2236cdb302 fix(smolvla): correct loss normalization for padded actions (#3434)
Apply the same per-scalar-mean fix to SmolVLA that #3377 landed for
ACT / Diffusion / MultiTaskDiT. The pre-patch form applies the
`action_is_pad` mask to zero out padded timesteps, then calls `.mean()`
(or `.mean(dim=(1, 2))`). Because `.mean()` divides by the total number
of elements including the zeroed padding, the loss is diluted by the
padding fraction.

Fixed by normalizing only over valid (non-padded) scalar entries:

    num_valid = ((~actions_is_pad).sum(...) * losses.shape[-1]).clamp_min(1)
    loss = losses.sum(...) / num_valid

`clamp_min(1)` preserves the all-padded-batch edge case (0/1 = 0). Both
reduction paths are updated. Behavior when `action_is_pad` is missing is
unchanged (`losses.mean()`).

Empirical A/B on aloha_sim_transfer_cube_human (chunk_size=40, batch=2,
30 steps, fixed seed, GB200) shows `loss_A / loss_B = 0.9672 (±0.088)` —
same direction and magnitude as PR #3377's `loss_A / loss_C ≈ 0.96` for
ACT. Heavier-padding recipes will see a larger gap.

Refs: #3353 (original report for ACT), #3377 (fix for the other three
policies).
2026-04-23 10:34:11 +02:00
20 changed files with 254 additions and 99 deletions

View File

@@ -220,7 +220,7 @@ REAL_DIM = 12
# Postprocessing: Trim 20D predictions to 12D for deployment # Postprocessing: Trim 20D predictions to 12D for deployment
``` ```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details. See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
#### Auto Action Mode (Recommended) #### Auto Action Mode (Recommended)
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274) - [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
- [LeRobot Documentation](https://github.com/huggingface/lerobot) - [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py) - [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py) - [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py) - [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
## Contributing ## Contributing

View File

@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
""" """
import logging import logging
import sys
import time import time
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
from .configuration_realsense import RealSenseCameraConfig from .configuration_realsense import RealSenseCameraConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
class RealSenseCamera(Camera): class RealSenseCamera(Camera):
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
Args: Args:
config: The configuration settings for the camera. config: The configuration settings for the camera.
""" """
require_package("pyrealsense2", extra="intelrealsense") require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
super().__init__(config) super().__init__(config)
self.config = config self.config = config

View File

@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = ( abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
).mean() num_valid = valid_mask.sum() * abs_err.shape[-1]
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()} loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae: if self.config.use_vae:

View File

@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
f"{self.config.do_mask_loss_for_padding=}." f"{self.config.do_mask_loss_for_padding=}."
) )
in_episode_bound = ~batch["action_is_pad"] in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1) mask = in_episode_bound.unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3
# config # config
@dataclass
class GR00TN15Config(PretrainedConfig): class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5" model_type = "gr00t_n1_5"
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."}) backbone_cfg: dict
action_head_cfg: dict
action_horizon: int = field(init=False, metadata={"help": "Action horizon."}) action_horizon: int
action_dim: int
action_dim: int = field(init=False, metadata={"help": "Action dimension."}) compute_dtype: str = "float32"
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
loss = F.mse_loss(predicted, target, reduction="none") loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch: if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"] mask = ~batch["action_is_pad"].unsqueeze(-1)
loss = loss * valid_actions.unsqueeze(-1) num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch: if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"] mask = ~batch["action_is_pad"].unsqueeze(-1)
loss = loss * valid_mask.unsqueeze(-1) num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean() return loss.mean()

View File

@@ -455,7 +455,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings # Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
output = self.clip_model.get_image_features(**inputs)
if not isinstance(output, torch.Tensor):
output = output.pooler_output
if output is None:
raise ValueError("pooler_output should not be None for CLIP models.")
embeddings = output.detach().cpu()
# Handle single frame case # Handle single frame case
if embeddings.dim() == 1: if embeddings.dim() == 1:
@@ -482,7 +488,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True) inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
output = self.clip_model.get_text_features(**inputs)
if not isinstance(output, torch.Tensor):
output = output.pooler_output
if output is None:
raise ValueError("pooler_output should not be None for CLIP models.")
text_embedding = output.detach().cpu()
text_embedding = text_embedding.expand(batch_size, -1) text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding return text_embedding

View File

@@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy):
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
if reduction == "none": if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims # Return per-sample losses (B,) by averaging over valid (time, action) entries
if actions_is_pad is None:
per_sample_loss = losses.mean(dim=(1, 2)) per_sample_loss = losses.mean(dim=(1, 2))
else:
num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1)
per_sample_loss = losses.sum(dim=(1, 2)) / num_valid
loss_dict["loss"] = per_sample_loss.mean().item() loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict return per_sample_loss, loss_dict
else: else:
# Default: return scalar mean loss # Default: return scalar mean loss over valid (time, action) entries
if actions_is_pad is None:
loss = losses.mean() loss = losses.mean()
else:
num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1)
loss = losses.sum() / num_valid
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict

View File

@@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
This step normalizes the `transition` object by: This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`. 1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
3. Copying `discrete_penalty` from `info` to `complementary_data`.
""" """
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
if TELEOP_ACTION_KEY in info: if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if DISCRETE_PENALTY_KEY in info:
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
if "is_intervention" in info: if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
@@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register("gripper_penalty_processor") @ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep): class GripperPenaltyProcessorStep(ProcessorStep):
""" """
Applies a penalty for inefficient gripper usage. Applies a small per-transition cost on the discrete gripper action.
This step penalizes actions that attempt to close an already closed gripper or Fires only when the commanded action would actually transition the gripper
open an already open one, based on position thresholds. from one extreme to the other (close-while-open or open-while-closed).
This discourages gripper oscillation while leaving "stay" and saturating-further
commands unpenalized.
Attributes: Attributes:
penalty: The negative reward value to apply. penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization. max_gripper_pos: The maximum position value for the gripper, used for normalization.
open_threshold: Normalized state below which the gripper is considered "open".
closed_threshold: Normalized state above which the gripper is considered "closed".
""" """
penalty: float = -0.01 penalty: float = -0.02
max_gripper_pos: float = 30.0 max_gripper_pos: float = 30.0
open_threshold: float = 0.1
closed_threshold: float = 0.9
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
""" """
@@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original # Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( # - currently open AND target is closed -> close transition
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 # - currently closed AND target is open -> open transition
) is_open = gripper_state_normalized < self.open_threshold
is_closed = gripper_state_normalized > self.closed_threshold
cmd_close = gripper_action_normalized > self.closed_threshold
cmd_open = gripper_action_normalized < self.open_threshold
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
gripper_penalty = self.penalty * int(gripper_penalty_bool) gripper_penalty = self.penalty * int(gripper_penalty_bool)
@@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
Returns the configuration of the step for serialization. Returns the configuration of the step for serialization.
Returns: Returns:
A dictionary containing the penalty value and max gripper position. A dictionary containing the penalty value, max gripper position,
and the open/closed thresholds.
""" """
return { return {
"penalty": self.penalty, "penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos, "max_gripper_pos": self.max_gripper_pos,
"open_threshold": self.open_threshold,
"closed_threshold": self.closed_threshold,
} }
def reset(self) -> None: def reset(self) -> None:

View File

@@ -134,6 +134,15 @@ class _NormalizationMixin:
if self.dtype is None: if self.dtype is None:
self.dtype = torch.float32 self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to( def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -152,6 +161,7 @@ class _NormalizationMixin:
if dtype is not None: if dtype is not None:
self.dtype = dtype self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self return self
def state_dict(self) -> dict[str, Tensor]: def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +211,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats # Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized # But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return return
# Normal behavior: load stats from state_dict # Normal behavior: load stats from state_dict
@@ -211,6 +222,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device dtype=torch.float32, device=self.device
) )
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method # Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats # and other functions that rely on self.stats

View File

@@ -60,7 +60,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.cameras import opencv # noqa: F401 from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import make_policy from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -89,9 +89,9 @@ from lerobot.utils.utils import (
) )
from .gym_manipulator import ( from .gym_manipulator import (
create_transition,
make_processors, make_processors,
make_robot_env, make_robot_env,
reset_and_build_transition,
step_env_and_process_transition, step_env_and_process_transition,
) )
from .process import ProcessSignalHandler from .process import ProcessSignalHandler
@@ -261,13 +261,12 @@ def act_with_policy(
policy = policy.eval() policy = policy.eval()
assert isinstance(policy, nn.Module) assert isinstance(policy, nn.Module)
obs, info = online_env.reset() preprocessor, postprocessor = make_pre_post_processors(
env_processor.reset() policy_cfg=cfg.policy,
action_processor.reset() dataset_stats=cfg.policy.dataset_stats,
)
# Process initial observation transition = reset_and_build_transition(online_env, env_processor, action_processor)
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# NOTE: For the moment we will solely handle the case of a single environment # NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0 sum_reward_episode = 0
@@ -291,8 +290,21 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement # Time policy inference and check if it meets FPS requirement
with policy_timer: with policy_timer:
# Extract observation from transition for policy normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=observation) action = policy.select_action(batch=normalized_observation)
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
# `select_action` concatenates an argmax index in env space at the last dim;
# action stats cover the continuous dims only, so feeding the full vector to
# the unnormalizer would shape-mismatch and would also corrupt the discrete
# index by treating it as a normalized value.
if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to(
device=continuous_action.device, dtype=continuous_action.dtype
)
action = torch.cat([continuous_action, discrete_action], dim=-1)
else:
action = postprocessor.process_action(action)
policy_fps = policy_timer.fps_last policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -326,7 +338,8 @@ def act_with_policy(
# Check for intervention from transition info # Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO] intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False): is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
if is_intervention:
episode_intervention = True episode_intervention = True
episode_intervention_steps += 1 episode_intervention_steps += 1
@@ -334,6 +347,10 @@ def act_with_policy(
"discrete_penalty": torch.tensor( "discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)] [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
), ),
# Forward the intervention flag so the learner can route this transition
# into the offline replay buffer (see `process_transitions` in learner.py).
# Use the plain string key so the payload survives torch.load(weights_only=True).
TeleopEvents.IS_INTERVENTION.value: is_intervention,
} }
# Create transition for learner (convert to old format) # Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append( list_transition_to_send_to_learner.append(
@@ -390,14 +407,7 @@ def act_with_policy(
episode_intervention_steps = 0 episode_intervention_steps = 0
episode_total_steps = 0 episode_total_steps = 0
# Reset environment and processors transition = reset_and_build_transition(online_env, env_processor, action_processor)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
if cfg.env.fps is not None: if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time dt_time = time.perf_counter() - start_time

View File

@@ -383,9 +383,20 @@ def make_processors(
GymHILAdapterProcessorStep(), GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(), Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(), VanillaObservationProcessorStep(),
]
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
env_pipeline_steps.extend(
[
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device), DeviceProcessorStep(device=device),
] ]
)
return DataProcessorPipeline( return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
@@ -551,8 +562,19 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE] terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
# Merge env and action-processor info: env wins for str keys, action-processor
# wins for `TeleopEvents` enum keys
action_info = processed_action_transition[TransitionKey.INFO]
new_info = info.copy() new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO]) for key, value in action_info.items():
if isinstance(key, TeleopEvents):
new_info[key] = value
new_transition = create_transition( new_transition = create_transition(
observation=obs, observation=obs,
@@ -568,6 +590,24 @@ def step_env_and_process_transition(
return new_transition return new_transition
def reset_and_build_transition(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
"""Reset env + processors and return the first env-processed transition."""
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
complementary_data: dict[str, Any] = {}
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
return env_processor(data=transition)
def control_loop( def control_loop(
env: gym.Env, env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
@@ -593,17 +633,7 @@ def control_loop(
print("- When not intervening, robot will stay still") print("- When not intervening, robot will stay still")
print("- Press Ctrl+C to exit") print("- Press Ctrl+C to exit")
# Reset environment and processors transition = reset_and_build_transition(env, env_processor, action_processor)
obs, info = env.reset()
complementary_data = (
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
)
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(data=transition)
# Determine if gripper is used # Determine if gripper is used
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
@@ -665,7 +695,7 @@ def control_loop(
# Create a neutral action (no movement) # Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper: if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function # Use the new step function
transition = step_env_and_process_transition( transition = step_env_and_process_transition(
@@ -723,12 +753,7 @@ def control_loop(
dataset.save_episode() dataset.save_episode()
# Reset for new episode # Reset for new episode
obs, info = env.reset() transition = reset_and_build_transition(env, env_processor, action_processor)
env_processor.reset()
action_processor.reset()
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# Maintain fps timing # Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))

View File

@@ -70,7 +70,7 @@ from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so_follower # noqa: F401 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -317,6 +317,11 @@ def add_actor_information_and_train(
policy.train() policy.train()
preprocessor, _postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time() last_time_policy_pushed = time.time()
@@ -405,8 +410,8 @@ def add_actor_information_and_train(
actions = batch[ACTION] actions = batch[ACTION]
rewards = batch["reward"] rewards = batch["reward"]
observations = batch["state"] observations = preprocessor.process_observation(batch["state"])
next_observations = batch["next_state"] next_observations = preprocessor.process_observation(batch["next_state"])
done = batch["done"] done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
@@ -463,8 +468,8 @@ def add_actor_information_and_train(
actions = batch[ACTION] actions = batch[ACTION]
rewards = batch["reward"] rewards = batch["reward"]
observations = batch["state"] observations = preprocessor.process_observation(batch["state"])
next_observations = batch["next_state"] next_observations = preprocessor.process_observation(batch["next_state"])
done = batch["done"] done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
@@ -1163,7 +1168,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention # Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get( if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
TeleopEvents.IS_INTERVENTION TeleopEvents.IS_INTERVENTION.value
): ):
offline_replay_buffer.add(**transition) offline_replay_buffer.add(**transition)

View File

@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
speed_factor: A scaling factor to convert the normalized velocity command to a position change. speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position. clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position. clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay). discrete_gripper: If True, interpret the input as a discrete class index
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
""" """
speed_factor: float = 20.0 speed_factor: float = 20.0
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
raise ValueError("Joints observation is require for computing robot kinematics") raise ValueError("Joints observation is require for computing robot kinematics")
if self.discrete_gripper: if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2] # Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
# 0: open, 1: close, 2: stay # Negation accounts for SO100 sign (joint position increases on close).
# We need to shift them to [-1, 0, 1] and then scale them to clip_max # 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
gripper_vel = (gripper_vel - 1) * self.clip_max gripper_vel = -(gripper_vel - 1) * self.clip_max
# Compute desired gripper position # Compute desired gripper position
delta = gripper_vel * float(self.speed_factor) delta = gripper_vel * float(self.speed_factor)

View File

@@ -150,11 +150,24 @@ Show dataset information without feature details:
--operation.type info \ --operation.type info \
--operation.show_features false --operation.show_features false
Recompute dataset statistics: Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default):
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht \ --repo_id lerobot/pusht \
--operation.type recompute_stats --operation.type recompute_stats
Recompute stats and save to a specific new repo_id:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_new_stats \
--operation.type recompute_stats
Recompute stats in-place (overwrites original dataset stats):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht \
--operation.type recompute_stats \
--operation.overwrite true
Recompute stats for relative actions and push to hub: Recompute stats for relative actions and push to hub:
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht \ --repo_id lerobot/pusht \
@@ -256,6 +269,7 @@ class RecomputeStatsConfig(OperationConfig):
relative_exclude_joints: list[str] | None = None relative_exclude_joints: list[str] | None = None
chunk_size: int = 50 chunk_size: int = 50
num_workers: int = 0 num_workers: int = 0
overwrite: bool = False
@OperationConfig.register_subclass("info") @OperationConfig.register_subclass("info")
@@ -280,16 +294,30 @@ class EditDatasetConfig:
push_to_hub: bool = False push_to_hub: bool = False
def _resolve_io_paths(
repo_id: str,
new_repo_id: str | None,
root: Path | str | None,
new_root: Path | str | None,
default_new_repo_id: str | None = None,
) -> tuple[str, Path, Path]:
"""Resolve input/output paths and repo_id for dataset operations.
Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths.
"""
input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve()
output_repo_id = new_repo_id or default_new_repo_id or repo_id
output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve()
return output_repo_id, input_path, output_path
def get_output_path( def get_output_path(
repo_id: str, repo_id: str,
new_repo_id: str | None, new_repo_id: str | None,
root: Path | str | None, root: Path | str | None,
new_root: Path | str | None, new_root: Path | str | None,
) -> tuple[str, Path]: ) -> tuple[str, Path]:
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root)
output_repo_id = new_repo_id if new_repo_id else repo_id
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
# In case of in-place modification, create a backup of the original dataset (if it exists) # In case of in-place modification, create a backup of the original dataset (if it exists)
if output_path == input_path: if output_path == input_path:
@@ -557,7 +585,39 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, RecomputeStatsConfig): if not isinstance(cfg.operation, RecomputeStatsConfig):
raise ValueError("Operation config must be RecomputeStatsConfig") raise ValueError("Operation config must be RecomputeStatsConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) # Determine whether this is an in-place operation
output_repo_id, input_root, output_root = _resolve_io_paths(
cfg.repo_id,
cfg.new_repo_id,
cfg.root,
cfg.new_root,
default_new_repo_id=f"{cfg.repo_id}_recomputed_stats",
)
in_place = output_root == input_root
if in_place and not cfg.operation.overwrite:
raise ValueError(
f"recompute_stats would overwrite the dataset in-place at {input_root}. "
"Pass --operation.overwrite true to allow in-place modification, "
"or use --new_repo_id / --new_root to write to a different location. "
f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'."
)
if in_place:
logging.warning(
f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost."
)
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
else:
logging.info(f"Copying dataset from {input_root} to {output_root}")
if output_root.exists():
backup_path = output_root.with_name(output_root.name + "_old")
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(output_root, backup_path)
shutil.copytree(input_root, output_root)
dataset = LeRobotDataset(output_repo_id, root=output_root)
logging.info(f"Recomputing stats for {cfg.repo_id}") logging.info(f"Recomputing stats for {cfg.repo_id}")
if cfg.operation.relative_action: if cfg.operation.relative_action:
@@ -578,7 +638,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
logging.info(f"Stats written to {dataset.root}") logging.info(f"Stats written to {dataset.root}")
if cfg.push_to_hub: if cfg.push_to_hub:
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...") logging.info(f"Pushing to hub as {dataset.repo_id}...")
dataset.push_to_hub() dataset.push_to_hub()

View File

@@ -115,7 +115,9 @@ _feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="
_reachy2_sdk_available = is_package_available("reachy2_sdk") _reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can") _can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py") _unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pyrealsense2_available = is_package_available("pyrealsense2") _pyrealsense2_available = is_package_available("pyrealsense2") or is_package_available(
"pyrealsense2-macosx", import_name="pyrealsense2"
)
_zmq_available = is_package_available("pyzmq", import_name="zmq") _zmq_available = is_package_available("pyzmq", import_name="zmq")
_hebi_available = is_package_available("hebi-py", import_name="hebi") _hebi_available = is_package_available("hebi-py", import_name="hebi")
_teleop_available = is_package_available("teleop") _teleop_available = is_package_available("teleop")