Compare commits

..

3 Commits

Author SHA1 Message Date
Steven Palma
d0ae3e9481 refactor(dataset): update imports across the codebase 2026-03-15 23:09:52 -07:00
Steven Palma
26d732c8c8 refactor(dataset): modular files 2026-03-15 23:07:52 -07:00
Steven Palma
c7458c67cd chore(dataset): basic house-keeping 2026-03-15 21:26:58 -07:00
12 changed files with 236 additions and 622 deletions

View File

@@ -19,8 +19,6 @@
title: Multi GPU training
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3

View File

@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).

View File

@@ -204,26 +204,22 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
Your dataset includes:
**Your Actions (2 features)**:
**Your Actions (2 things)**:
- `linear_velocity`: How much you moved forward/backward
- `angular_velocity`: How much you turned left/right
- How much you moved forward/backward
- How much you turned left/right
**Robot Observations (24 features)**:
**Robot Observations (12 things)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Orientation
- GPS (latitude, longitude, signal strength)
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp state (on/off)
- Accelerometer (x, y, z)
- Gyroscope (x, y, z)
- Magnetometer (x, y, z)
- Wheel RPMs (4 wheels)
- Lamp status (on/off)
### Where Your Data Goes

View File

@@ -1,114 +0,0 @@
# Rename Map and Empty Cameras
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
## Why observation keys don't always match
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
## Using the rename map
Pass the mapping as a JSON string on the command line. The convention is always:
```
--rename_map='{"source_key": "policy_key", ...}'
```
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
### Training
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
```
### Evaluation
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
```bash
lerobot-eval \
--policy.path=lerobot/pi0fast-libero \
--env.type=libero \
... \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
## Empty cameras: fewer views than the policy expects
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
### How it works
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
```
observation.images.empty_camera_0
observation.images.empty_camera_1
...
```
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
### Example
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
1. Set `--policy.empty_cameras=1`.
2. The config adds a third key: `observation.images.empty_camera_0`.
3. Use the rename map for your two real cameras as usual.
4. The third slot is masked out — no fake images needed in your dataset.
## Quick reference
| Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |

View File

@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
self.vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
self.vqvae_model.discretized.fill_(True)
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():

View File

@@ -131,15 +131,6 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -158,7 +149,6 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -208,7 +198,6 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -220,8 +209,6 @@ class _NormalizationMixin:
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
self.stats = {}

View File

@@ -62,7 +62,6 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue
from lerobot.robots import so_follower # noqa: F401
@@ -259,11 +258,6 @@ def act_with_policy(
policy = policy.eval()
assert isinstance(policy, nn.Module)
preprocessor, postprocessor = make_sac_pre_post_processors(
config=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
@@ -295,9 +289,7 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
# action = postprocessor.process_action(action)
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

View File

@@ -66,7 +66,6 @@ from lerobot.datasets.factory import make_dataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.wandb_utils import WandBLogger
@@ -314,11 +313,6 @@ def add_actor_information_and_train(
assert isinstance(policy, nn.Module)
preprocessor, _ = make_sac_pre_post_processors(
config=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -414,9 +408,6 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations = preprocessor.process_observation(observations)
next_observations = preprocessor.process_observation(next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
@@ -476,9 +467,6 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations = preprocessor.process_observation(observations)
next_observations = preprocessor.process_observation(next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)

View File

@@ -33,40 +33,21 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
logger = logging.getLogger(__name__)
# Action feature keys
ACTION_LINEAR_VEL = "linear_velocity"
ACTION_ANGULAR_VEL = "angular_velocity"
ACTION_LINEAR_VEL = "linear.vel"
ACTION_ANGULAR_VEL = "angular.vel"
# Observation feature keys — cameras
# Observation feature keys
OBS_FRONT = "front"
OBS_REAR = "rear"
# Observation feature keys — telemetry
OBS_SPEED = "speed"
OBS_BATTERY_LEVEL = "battery_level"
OBS_ORIENTATION = "orientation"
OBS_GPS_LATITUDE = "gps_latitude"
OBS_GPS_LONGITUDE = "gps_longitude"
OBS_GPS_SIGNAL = "gps_signal"
OBS_SIGNAL_LEVEL = "signal_level"
OBS_LINEAR_VEL = "linear.vel"
OBS_BATTERY_LEVEL = "battery.level"
OBS_ORIENTATION_DEG = "orientation.deg"
OBS_GPS_LATITUDE = "gps.latitude"
OBS_GPS_LONGITUDE = "gps.longitude"
OBS_GPS_SIGNAL = "gps.signal"
OBS_SIGNAL_LEVEL = "signal.level"
OBS_VIBRATION = "vibration"
OBS_LAMP = "lamp"
# Observation feature keys — IMU sensors
OBS_ACCELEROMETER_X = "accelerometer_x"
OBS_ACCELEROMETER_Y = "accelerometer_y"
OBS_ACCELEROMETER_Z = "accelerometer_z"
OBS_GYROSCOPE_X = "gyroscope_x"
OBS_GYROSCOPE_Y = "gyroscope_y"
OBS_GYROSCOPE_Z = "gyroscope_z"
OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
# Observation feature keys — wheel RPMs
OBS_WHEEL_RPM_0 = "wheel_rpm_0"
OBS_WHEEL_RPM_1 = "wheel_rpm_1"
OBS_WHEEL_RPM_2 = "wheel_rpm_2"
OBS_WHEEL_RPM_3 = "wheel_rpm_3"
OBS_LAMP_STATE = "lamp.state"
class EarthRoverMiniPlus(Robot):
@@ -173,60 +154,33 @@ class EarthRoverMiniPlus(Robot):
dict: Observation features with types/shapes:
- front: (480, 640, 3) - Front camera RGB image
- rear: (480, 640, 3) - Rear camera RGB image
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
- battery.level: float - Battery level (0-1, normalized from 0-100)
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
- gps.latitude: float - GPS latitude coordinate
- gps.longitude: float - GPS longitude coordinate
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
- signal.level: float - Network signal level (0-1, normalized from 0-5)
- vibration: float - Vibration sensor reading
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x: float - Accelerometer X axis (raw SDK value)
- accelerometer_y: float - Accelerometer Y axis (raw SDK value)
- accelerometer_z: float - Accelerometer Z axis (raw SDK value)
- gyroscope_x: float - Gyroscope X axis (raw SDK value)
- gyroscope_y: float - Gyroscope Y axis (raw SDK value)
- gyroscope_z: float - Gyroscope Z axis (raw SDK value)
- magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
- magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
- magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
- wheel_rpm_0: float - Wheel 0 RPM
- wheel_rpm_1: float - Wheel 1 RPM
- wheel_rpm_2: float - Wheel 2 RPM
- wheel_rpm_3: float - Wheel 3 RPM
- lamp.state: float - Lamp state (0=off, 1=on)
"""
return {
# Cameras (height, width, channels)
OBS_FRONT: (480, 640, 3),
OBS_REAR: (480, 640, 3),
# Telemetry
OBS_SPEED: float,
# Motion state
OBS_LINEAR_VEL: float,
# Robot state
OBS_BATTERY_LEVEL: float,
OBS_ORIENTATION: float,
OBS_ORIENTATION_DEG: float,
# GPS
OBS_GPS_LATITUDE: float,
OBS_GPS_LONGITUDE: float,
OBS_GPS_SIGNAL: float,
# Sensors
OBS_SIGNAL_LEVEL: float,
OBS_VIBRATION: float,
OBS_LAMP: float,
# IMU — accelerometer
OBS_ACCELEROMETER_X: float,
OBS_ACCELEROMETER_Y: float,
OBS_ACCELEROMETER_Z: float,
# IMU — gyroscope
OBS_GYROSCOPE_X: float,
OBS_GYROSCOPE_Y: float,
OBS_GYROSCOPE_Z: float,
# IMU — magnetometer
OBS_MAGNETOMETER_X: float,
OBS_MAGNETOMETER_Y: float,
OBS_MAGNETOMETER_Z: float,
# Wheel RPMs
OBS_WHEEL_RPM_0: float,
OBS_WHEEL_RPM_1: float,
OBS_WHEEL_RPM_2: float,
OBS_WHEEL_RPM_3: float,
OBS_LAMP_STATE: float,
}
@cached_property
@@ -235,8 +189,8 @@ class EarthRoverMiniPlus(Robot):
Returns:
dict: Action features with types:
- linear_velocity: float - Target linear velocity (-1 to 1)
- angular_velocity: float - Target angular velocity (-1 to 1)
- linear.vel: float - Target linear velocity
- angular.vel: float - Target angular velocity
"""
return {
ACTION_LINEAR_VEL: float,
@@ -247,29 +201,19 @@ class EarthRoverMiniPlus(Robot):
def get_observation(self) -> RobotObservation:
"""Get current robot observation from SDK.
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
Frames are decoded from base64 and converted from BGR to RGB format.
Robot telemetry is retrieved from /data endpoint.
Sensor arrays (accels, gyros, mags, rpms) each contain entries of
[values..., timestamp]; the latest reading from each array is used.
Returns:
RobotObservation: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- vibration: float - Vibration sensor reading
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
- gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
- magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
- wheel_rpm_0/1/2/3: float - Wheel RPMs
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
- battery.level: Battery level (0-1, normalized from 0-100)
- orientation.deg: Robot orientation (0-1, normalized from raw value)
- gps.latitude: GPS latitude coordinate
- gps.longitude: GPS longitude coordinate
- gps.signal: GPS signal strength (0-1, normalized from percentage)
- signal.level: Network signal level (0-1, normalized from 0-5)
- vibration: Vibration sensor reading
- lamp.state: Lamp state (0=off, 1=on)
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -291,41 +235,22 @@ class EarthRoverMiniPlus(Robot):
# Get robot state from SDK
robot_data = self._get_robot_data()
# Telemetry
observation[OBS_SPEED] = float(robot_data["speed"])
observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
observation[OBS_ORIENTATION] = float(robot_data["orientation"])
observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
observation[OBS_VIBRATION] = float(robot_data["vibration"])
observation[OBS_LAMP] = float(robot_data["lamp"])
# Motion state
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
# Accelerometer — latest reading from accels array [x, y, z, ts]
accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
observation[OBS_ACCELEROMETER_X] = accel[0]
observation[OBS_ACCELEROMETER_Y] = accel[1]
observation[OBS_ACCELEROMETER_Z] = accel[2]
# Robot state
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
# Gyroscope — latest reading from gyros array [x, y, z, ts]
gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
observation[OBS_GYROSCOPE_X] = gyro[0]
observation[OBS_GYROSCOPE_Y] = gyro[1]
observation[OBS_GYROSCOPE_Z] = gyro[2]
# GPS data
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
# Magnetometer — latest reading from mags array [x, y, z, ts]
mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
observation[OBS_MAGNETOMETER_X] = mag[0]
observation[OBS_MAGNETOMETER_Y] = mag[1]
observation[OBS_MAGNETOMETER_Z] = mag[2]
# Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
observation[OBS_WHEEL_RPM_0] = rpm[0]
observation[OBS_WHEEL_RPM_1] = rpm[1]
observation[OBS_WHEEL_RPM_2] = rpm[2]
observation[OBS_WHEEL_RPM_3] = rpm[3]
# Sensors
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
observation[OBS_VIBRATION] = robot_data["vibration"]
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
return observation
@@ -335,12 +260,11 @@ class EarthRoverMiniPlus(Robot):
Args:
action: Action dict with keys:
- linear_velocity: Target linear velocity (-1 to 1)
- angular_velocity: Target angular velocity (-1 to 1)
- linear.vel: Target linear velocity (-1 to 1)
- angular.vel: Target angular velocity (-1 to 1)
Returns:
RobotAction: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -348,14 +272,18 @@ class EarthRoverMiniPlus(Robot):
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
# Send command to SDK
try:
self._send_command_to_sdk(linear, angular)
except Exception as e:
logger.error(f"Error sending action: {e}")
# Return action in format matching action_features
return {
ACTION_LINEAR_VEL: linear,
ACTION_ANGULAR_VEL: angular,
@@ -466,27 +394,11 @@ class EarthRoverMiniPlus(Robot):
logger.error(f"Error decoding image: {e}")
return None
@staticmethod
def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
"""Extract the latest sensor reading from an SDK sensor array.
The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
This helper returns the *n_values* leading floats from the last entry,
falling back to zeros when the key is missing or the array is empty.
"""
readings = robot_data.get(key)
if readings and len(readings) > 0:
latest = readings[-1]
return [float(v) for v in latest[:n_values]]
return [0.0] * n_values
def _get_robot_data(self) -> dict:
"""Get robot telemetry data from SDK.
Returns:
dict: Robot telemetry data including battery, speed, orientation, GPS,
and sensor arrays (accels, gyros, mags, rpms):
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
- Current data (if request succeeds)
- Cached data (if request fails but cache exists)
- Default values (if request fails and no cache exists yet)
@@ -508,23 +420,19 @@ class EarthRoverMiniPlus(Robot):
# Fallback: use cache or default values
if self._last_robot_data is not None:
return self._last_robot_data
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
"accels": [],
"gyros": [],
"mags": [],
"rpms": [],
}
else:
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
}
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
"""Send control command to SDK.

View File

@@ -23,46 +23,65 @@ class InputController:
"""Base class for input controllers that generate motion deltas."""
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
"""
Initialize the controller.
Args:
x_step_size: Base movement step size in meters
y_step_size: Base movement step size in meters
z_step_size: Base movement step size in meters
"""
self.x_step_size = x_step_size
self.y_step_size = y_step_size
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None
self.episode_end_status = None # None, "success", or "failure"
self.intervention_flag = False
self.open_gripper_command = False
self.close_gripper_command = False
def start(self):
"""Start the controller and initialize resources."""
pass
def stop(self):
pass
def reset(self):
"""Stop the controller and release resources."""
pass
def get_deltas(self):
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def update(self):
"""Update controller state - call this once per frame."""
pass
def __enter__(self):
"""Support for use in 'with' statements."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensure resources are released when exiting 'with' block."""
self.stop()
def get_episode_end_status(self):
"""
Get the current episode end status.
Returns:
None if episode should continue, "success" or "failure" otherwise
"""
status = self.episode_end_status
self.episode_end_status = None
self.episode_end_status = None # Reset after reading
return status
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def gripper_command(self):
"""Return the current gripper command."""
if self.open_gripper_command == self.close_gripper_command:
return "stay"
elif self.open_gripper_command:
@@ -83,14 +102,14 @@ class KeyboardController(InputController):
"backward_y": False,
"forward_z": False,
"backward_z": False,
"quit": False,
"success": False,
"failure": False,
"intervention": False,
"rerecord": False,
}
self.listener = None
def start(self):
"""Start the keyboard listener."""
from pynput import keyboard
def on_press(key):
@@ -107,21 +126,16 @@ class KeyboardController(InputController):
self.key_states["backward_z"] = True
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = True
elif key == keyboard.Key.ctrl_r:
self.open_gripper_command = True
elif key == keyboard.Key.ctrl_l:
self.close_gripper_command = True
elif key == keyboard.Key.esc:
self.key_states["quit"] = True
self.running = False
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = TeleopEvents.SUCCESS
elif key == keyboard.Key.esc:
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = TeleopEvents.FAILURE
elif key == keyboard.Key.space:
self.key_states["intervention"] = not self.key_states["intervention"]
elif hasattr(key, "char") and key.char == "r":
self.key_states["rerecord"] = True
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
except AttributeError:
pass
@@ -139,10 +153,10 @@ class KeyboardController(InputController):
self.key_states["backward_z"] = False
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = False
elif key == keyboard.Key.ctrl_r:
self.open_gripper_command = False
elif key == keyboard.Key.ctrl_l:
self.close_gripper_command = False
elif key == keyboard.Key.enter:
self.key_states["success"] = False
elif key == keyboard.Key.backspace:
self.key_states["failure"] = False
except AttributeError:
pass
@@ -151,18 +165,18 @@ class KeyboardController(InputController):
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift / Shift_R: Move in Z axis")
print(" Ctrl_R / Ctrl_L: Open / Close gripper")
print(" Space: Toggle intervention")
print(" Shift and Shift_R: Move in Z axis")
print(" Enter: End episode with SUCCESS")
print(" Esc: End episode with FAILURE")
print(" R: Rerecord episode")
print(" Backspace: End episode with FAILURE")
print(" ESC: Exit")
def stop(self):
"""Stop the keyboard listener."""
if self.listener and self.listener.is_alive():
self.listener.stop()
def get_deltas(self):
"""Get the current movement deltas from keyboard state."""
delta_x = delta_y = delta_z = 0.0
if self.key_states["forward_x"]:
@@ -180,58 +194,18 @@ class KeyboardController(InputController):
return delta_x, delta_y, delta_z
def should_intervene(self):
return self.key_states["intervention"]
def reset(self):
for key in self.key_states:
self.key_states[key] = False
class GamepadController(InputController):
"""Generate motion deltas from gamepad input using pygame.
Matches gym-hil button/axis conventions for Linux gamepads, including
Xbox mappings.
"""
# Face buttons (same across most controllers on Linux)
BUTTON_A = 0
BUTTON_B = 1
BUTTON_X = 2
BUTTON_Y = 3
BUTTON_LB = 4
BUTTON_RB = 5
# Stick axes
AXIS_LEFT_X = 0
AXIS_LEFT_Y = 1
AXIS_RIGHT_X = 2
AXIS_RIGHT_Y = 3
# Default trigger buttons
BUTTON_LT = 6
BUTTON_RT = 7
# Xbox (gym-hil mapping on Linux)
XBOX_BUTTON_LT = 9
XBOX_BUTTON_RT = 10
"""Generate motion deltas from gamepad input."""
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
self.intervention_flag = False
self.is_xbox = False
self._xbox360_profile = False
self._invert_left_x = False
self._invert_left_y = True
self._invert_right_y = True
def _detect_xbox(self, name):
name_lower = name.lower()
return any(tag in name_lower for tag in ["xbox", "microsoft", "x-box"])
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
@@ -244,35 +218,18 @@ class GamepadController(InputController):
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
joystick_name = self.joystick.get_name()
self.is_xbox = self._detect_xbox(joystick_name)
self._xbox360_profile = joystick_name == "Xbox 360 Controller"
if self._xbox360_profile:
# gym-hil "Xbox 360 Controller" profile
self.AXIS_RIGHT_X = 3
self.AXIS_RIGHT_Y = 4
self.BUTTON_LT = self.XBOX_BUTTON_LT
self.BUTTON_RT = self.XBOX_BUTTON_RT
self._invert_left_x = True
else:
# gym-hil default profile
self.AXIS_RIGHT_X = 2
self.AXIS_RIGHT_Y = 3
self.BUTTON_LT = 6
self.BUTTON_RT = 7
self._invert_left_x = False
logging.info(f"Initialized gamepad: {joystick_name} (xbox={self.is_xbox})")
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick (vertical): Move in Z axis")
print(" RB: Intervention toggle")
print(" LT / RT: Close / Open gripper")
print(" Y: End episode with SUCCESS")
print(" A: End episode with FAILURE")
print(" X: Rerecord episode")
print(" B/Circle button: Exit")
print(" Y/Triangle button: End episode with SUCCESS")
print(" A/Cross button: End episode with FAILURE")
print(" X/Square button: Rerecord episode")
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
@@ -282,56 +239,67 @@ class GamepadController(InputController):
pygame.quit()
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == self.BUTTON_Y:
if event.button == 3:
self.episode_end_status = TeleopEvents.SUCCESS
elif event.button == self.BUTTON_A:
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = TeleopEvents.FAILURE
elif event.button == self.BUTTON_X:
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
elif event.button == self.BUTTON_LT:
# RB button (6) for closing gripper
elif event.button == 6:
self.close_gripper_command = True
elif event.button == self.BUTTON_RT:
# LT button (7) for opening gripper
elif event.button == 7:
self.open_gripper_command = True
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [self.BUTTON_Y, self.BUTTON_A, self.BUTTON_X]:
if event.button in [0, 2, 3]:
self.episode_end_status = None
elif event.button == self.BUTTON_LT:
elif event.button == 6:
self.close_gripper_command = False
elif event.button == self.BUTTON_RT:
elif event.button == 7:
self.open_gripper_command = False
if self.joystick.get_button(self.BUTTON_RB):
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):
self.intervention_flag = True
else:
self.intervention_flag = False
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
x_input = self.joystick.get_axis(self.AXIS_LEFT_X)
y_input = self.joystick.get_axis(self.AXIS_LEFT_Y)
z_input = self.joystick.get_axis(self.AXIS_RIGHT_Y)
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
x_input = self.joystick.get_axis(1) # Left/Right
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
# Apply deadzone to avoid drift
x_input = 0 if abs(x_input) < self.deadzone else x_input
y_input = 0 if abs(y_input) < self.deadzone else y_input
z_input = 0 if abs(z_input) < self.deadzone else z_input
if self._invert_left_x:
x_input = -x_input
if self._invert_left_y:
y_input = -y_input
if self._invert_right_y:
z_input = -z_input
delta_x = y_input * self.y_step_size
delta_y = x_input * self.x_step_size
delta_z = z_input * self.z_step_size
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -x_input * self.x_step_size # Forward/backward
delta_y = -y_input * self.y_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
@@ -341,15 +309,7 @@ class GamepadController(InputController):
class GamepadControllerHID(InputController):
"""Generate motion deltas from gamepad input using HIDAPI.
Supports auto-detection of controller type for correct HID report parsing.
Currently supported: Logitech RumblePad 2, 8BitDo Ultimate 2C Wireless.
"""
CONTROLLER_LOGITECH = "logitech"
CONTROLLER_8BITDO = "8bitdo"
CONTROLLER_UNKNOWN = "unknown"
"""Generate motion deltas from gamepad input using HIDAPI."""
def __init__(
self,
@@ -358,26 +318,36 @@ class GamepadControllerHID(InputController):
z_step_size=1.0,
deadzone=0.1,
):
"""
Initialize the HID gamepad controller.
Args:
step_size: Base movement step size in meters
z_scale: Scaling factor for Z-axis movement
deadzone: Joystick deadzone to prevent drift
"""
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.device = None
self.device_info = None
self.controller_type = self.CONTROLLER_UNKNOWN
# Movement values (normalized from -1.0 to 1.0)
self.left_x = 0.0
self.left_y = 0.0
self.right_x = 0.0
self.right_y = 0.0
# Button states
self.buttons = {}
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
device_name = device["product_string"]
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5", "8BitDo"]):
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
return device
logging.error(
@@ -385,15 +355,8 @@ class GamepadControllerHID(InputController):
)
return None
def _detect_controller_type(self, product_string):
product = product_string.lower() if product_string else ""
if "8bitdo" in product:
return self.CONTROLLER_8BITDO
elif "logitech" in product:
return self.CONTROLLER_LOGITECH
return self.CONTROLLER_UNKNOWN
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
@@ -411,22 +374,12 @@ class GamepadControllerHID(InputController):
product = self.device.get_product_string()
logging.info(f"Connected to {manufacturer} {product}")
self.controller_type = self._detect_controller_type(product)
logging.info(f"Detected controller type: {self.controller_type}")
print("Gamepad controls (HID mode):")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick: Move in Z axis (vertical)")
print(" RB: Intervention toggle")
if self.controller_type == self.CONTROLLER_8BITDO:
print(" L3 (left stick click): Close gripper")
print(" R3 (right stick click): Open gripper")
else:
print(" LT: Close gripper")
print(" RT: Open gripper")
print(" Y: End episode with SUCCESS")
print(" X: End episode with FAILURE")
print(" A: Rerecord episode")
logging.info("Gamepad controls (HID mode):")
logging.info(" Left analog stick: Move in X-Y plane")
logging.info(" Right analog stick: Move in Z axis (vertical)")
logging.info(" Button 1/B/Circle: Exit")
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
logging.info(" Button 3/X/Square: End episode with FAILURE")
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
@@ -434,124 +387,74 @@ class GamepadControllerHID(InputController):
self.running = False
def stop(self):
"""Close the HID device connection."""
if self.device:
self.device.close()
self.device = None
def update(self):
"""Read the device several times to drain the HID buffer and get a stable reading."""
"""
Read and process the latest gamepad data.
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
"""
for _ in range(10):
self._update()
def _update(self):
"""Read and process the latest gamepad data."""
if not self.device or not self.running:
return
try:
# Read data from the gamepad
data = self.device.read(64)
if not data:
return
# Interpret gamepad data - this will vary by controller model
# These offsets are for the Logitech RumblePad 2
if data and len(data) >= 8:
# Normalize joystick values from 0-255 to -1.0-1.0
self.left_y = (data[1] - 128) / 128.0
self.left_x = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
if self.controller_type == self.CONTROLLER_8BITDO:
self._parse_8bitdo(data)
else:
self._parse_logitech(data)
# Apply deadzone
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
# Check if RB is pressed then the intervention flag should be set
self.intervention_flag = data[6] in [2, 6, 10, 14]
# Check if RT is pressed
self.open_gripper_command = data[6] in [8, 10, 12]
# Check if LT is pressed
self.close_gripper_command = data[6] in [4, 6, 12]
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = TeleopEvents.SUCCESS
elif buttons & 1 << 5:
self.episode_end_status = TeleopEvents.FAILURE
elif buttons & 1 << 4:
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
else:
self.episode_end_status = None
except OSError as e:
logging.error(f"Error reading from gamepad: {e}")
def _apply_deadzone(self):
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
def _parse_8bitdo(self, data):
"""Parse HID report from 8BitDo Ultimate 2C Wireless (Bluetooth on macOS).
11-byte report layout:
byte[0]: Report ID (0x01)
byte[1]: D-pad hat switch (0=N, 2=E, 5=S, 6=W, 15=neutral)
byte[2]: Left Stick X (0=left, 127=center, 255=right)
byte[3]: Left Stick Y (0=up, 127=center, 255=down)
byte[4]: Right Stick X (inverted: 255=left, 0=right)
byte[5]: Right Stick Y (0=up, 127=center, 255=down)
byte[6]: RT analog trigger (0-255)
byte[7]: LT analog trigger (0-255)
byte[8]: Buttons -- bit0=A, bit1=B, bit3=X, bit4=Y, bit6=LB, bit7=RB
byte[9]: System -- bit0=LT(digital), bit1=RT(digital), bit3=Select,
bit4=Start, bit5=L3, bit6=R3
byte[10]: Unused
"""
if len(data) < 11:
return
self.left_x = (data[2] - 127) / 128.0
self.left_y = (data[3] - 127) / 128.0
self.right_x = -(data[4] - 127) / 128.0
self.right_y = (data[5] - 127) / 128.0
self._apply_deadzone()
buttons = data[8]
# RB (bit 7) = intervention
self.intervention_flag = bool(buttons & 0x80)
# Stick clicks for gripper: R3 (byte[9] bit6) = open, L3 (byte[9] bit5) = close
system = data[9]
self.open_gripper_command = bool(system & 0x40) # R3
self.close_gripper_command = bool(system & 0x20) # L3
# Y (bit 4) = success, X (bit 3) = failure, A (bit 0) = rerecord
if buttons & 0x10:
self.episode_end_status = TeleopEvents.SUCCESS
elif buttons & 0x08:
self.episode_end_status = TeleopEvents.FAILURE
elif buttons & 0x01:
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
else:
self.episode_end_status = None
def _parse_logitech(self, data):
"""Parse HID report from Logitech RumblePad 2 (and similar Logitech gamepads).
Report layout (8+ bytes):
byte[1]: Left Stick X (0-255, center=128)
byte[2]: Left Stick Y (0-255, center=128)
byte[3]: Right Stick X (0-255, center=128)
byte[4]: Right Stick Y (0-255, center=128)
byte[5]: Face buttons bitmask
byte[6]: Shoulder/trigger buttons bitmask
"""
if len(data) < 8:
return
self.left_x = (data[1] - 128) / 128.0
self.left_y = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
self._apply_deadzone()
buttons = data[5]
self.intervention_flag = data[6] in [2, 6, 10, 14]
self.open_gripper_command = data[6] in [8, 10, 12]
self.close_gripper_command = data[6] in [4, 6, 12]
if buttons & 1 << 7:
self.episode_end_status = TeleopEvents.SUCCESS
elif buttons & 1 << 5:
self.episode_end_status = TeleopEvents.FAILURE
elif buttons & 1 << 4:
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
else:
self.episode_end_status = None
def get_deltas(self):
delta_x = -self.left_y * self.x_step_size
delta_y = -self.left_x * self.y_step_size
delta_z = -self.right_y * self.z_step_size
"""Get the current movement deltas from gamepad state."""
# Calculate deltas - invert as needed based on controller orientation
delta_x = -self.left_x * self.x_step_size # Forward/backward
delta_y = -self.left_y * self.y_step_size # Left/right
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z

View File

@@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop):
def action_features(self) -> dict:
"""Return action format for rover (linear and angular velocities)."""
return {
"linear_velocity": float,
"angular_velocity": float,
"linear.vel": float,
"angular.vel": float,
}
@property
@@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
Get the current action based on pressed keys.
Returns:
RobotAction with 'linear_velocity' and 'angular_velocity' keys.
RobotAction with 'linear.vel' and 'angular.vel' keys
"""
before_read_t = time.perf_counter()
@@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
return {
"linear_velocity": linear_velocity,
"angular_velocity": angular_velocity,
"linear.vel": linear_velocity,
"angular.vel": angular_velocity,
}

View File

@@ -42,8 +42,6 @@ from lerobot.policies.factory import (
make_pre_post_processors,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -462,45 +460,3 @@ def test_act_temporal_ensembler():
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
def test_vqbet_discretize_keeps_buffers_on_device():
"""Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.
Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
registered buffer with a new CPU tensor, causing DDP to crash with:
RuntimeError: No backend type associated with device type cpu
The fix uses `.fill_(True)` to update in-place, preserving device placement.
"""
config = VQBeTConfig()
config.input_features = {
OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
}
# Tiny sizes for fast CPU/GPU execution.
config.n_vqvae_training_steps = 3
config.vqvae_n_embed = 8
config.vqvae_embedding_dim = 32
config.vqvae_enc_hidden_dim = 32
config.action_chunk_size = 2
config.crop_shape = (84, 84)
head = VQBeTHead(config).to(DEVICE)
vqvae = head.vqvae_model
dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)
n_steps = config.n_vqvae_training_steps
for _ in range(n_steps):
head.discretize(n_steps, dummy_actions)
assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
"vqvae_model.discretized was moved off the model device after discretize(). "
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
)
assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
"vq_layer.freeze_codebook was moved off the model device after discretize(). "
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
)