mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
3 Commits
fix/re-ena
...
chore/data
| Author | SHA1 | Date |
|---|---|---|
| | d0ae3e9481 | |
| | 26d732c8c8 | |
| | c7458c67cd | |
@@ -19,8 +19,6 @@
       title: Multi GPU training
     - local: peft_training
       title: Training with PEFT (e.g., LoRA)
-    - local: rename_map
-      title: Using Rename Map and Empty Cameras
   title: "Tutorials"
 - sections:
     - local: lerobot-dataset-v3
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
 - **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
 
 Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
-If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
+If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
@@ -204,26 +204,22 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
 
 Your dataset includes:
 
-**Your Actions (2 features)**:
+**Your Actions (2 things)**:
 
-- `linear_velocity`: How much you moved forward/backward
-- `angular_velocity`: How much you turned left/right
+- How much you moved forward/backward
+- How much you turned left/right
 
-**Robot Observations (24 features)**:
+**Robot Observations (12 things)**:
 
 - Front camera video
 - Rear camera video
 - Current speed
 - Battery level
-- Orientation
-- GPS (latitude, longitude, signal strength)
+- Which way the robot is facing
+- GPS location (latitude, longitude, signal strength)
 - Network signal strength
 - Vibration level
-- Lamp state (on/off)
-- Accelerometer (x, y, z)
-- Gyroscope (x, y, z)
-- Magnetometer (x, y, z)
-- Wheel RPMs (4 wheels)
+- Lamp status (on/off)
 
 ### Where Your Data Goes
 
@@ -1,114 +0,0 @@
-# Rename Map and Empty Cameras
-
-When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
-
-> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
-
-## Why observation keys don't always match
-
-Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
-
-- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
-- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
-
-Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
-
-## Using the rename map
-
-Pass the mapping as a JSON string on the command line. The convention is always:
-
-```
---rename_map='{"source_key": "policy_key", ...}'
-```
-
-where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
-
-Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
-
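The renaming rule described in the removed document (listed keys are renamed, all other keys pass through unchanged) can be sketched in a few lines of plain Python. This is only an illustration of the rule: `apply_rename_map` is a hypothetical helper, the string values are placeholders for real image and state tensors, and `json.loads` merely stands in for however the CLI actually parses the `--rename_map` value.

```python
import json


def apply_rename_map(observation: dict, rename_map: dict[str, str]) -> dict:
    """Return a copy of `observation` with listed keys renamed; other keys pass through."""
    return {rename_map.get(key, key): value for key, value in observation.items()}


# Same JSON shape as the value passed to --rename_map.
rename_map = json.loads(
    '{"observation.images.front": "observation.images.image",'
    ' "observation.images.eagle": "observation.images.image2"}'
)

# Placeholder strings stand in for image tensors and state vectors.
observation = {
    "observation.images.front": "front_frame",
    "observation.images.eagle": "eagle_frame",
    "observation.state": "joint_positions",
}

renamed = apply_rename_map(observation, rename_map)
print(sorted(renamed))
```

Because the lookup uses `rename_map.get(key, key)`, the order of entries in the map has no effect, which matches the behaviour the document describes.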
-Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
-
-### Training
-
-Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
-
-```bash
-lerobot-train \
-  --dataset.repo_id=YOUR_DATASET \
-  --output_dir=./outputs/xvla_training \
-  --job_name=xvla_training \
-  --policy.path="lerobot/xvla-base" \
-  --policy.repo_id="HF_USER/xvla-your-robot" \
-  --policy.dtype=bfloat16 \
-  --policy.action_mode=auto \
-  --steps=20000 \
-  --policy.device=cuda \
-  --policy.freeze_vision_encoder=false \
-  --policy.freeze_language_encoder=false \
-  --policy.train_policy_transformer=true \
-  --policy.train_soft_prompts=true \
-  --rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
-```
-
-### Evaluation
-
-A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
-
-```bash
-lerobot-eval \
-  --policy.path=lerobot/pi0fast-libero \
-  --env.type=libero \
-  ... \
-  --rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
-```
-
-### Recording
-
-`lerobot-record` also supports rename maps, nested under the dataset config:
-
-```bash
-lerobot-record \ # When running inference
-  --policy.path="<user>/smolVLA_finetuned" \
-  ... \
-  --dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
-```
-
-## Alternative: edit the policy config directly
-
-If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
-
-The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
-
-## Empty cameras: fewer views than the policy expects
-
-Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
-
-### How it works
-
-Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
-
-```
-observation.images.empty_camera_0
-observation.images.empty_camera_1
-...
-```
-
-At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
-
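The placeholder naming scheme for `empty_cameras=N` can be illustrated with a small sketch. It only generates the key names listed in the removed document and works out which image inputs have no data in a batch. `empty_camera_keys` and the surrounding variables are hypothetical, and the actual masking with dummy tensors is not reproduced here.

```python
def empty_camera_keys(n: int) -> list[str]:
    """Placeholder image feature names for `empty_cameras=n`, as listed in the doc."""
    return [f"observation.images.empty_camera_{i}" for i in range(n)]


# Two real cameras plus one placeholder slot (empty_cameras=1).
real_keys = ["observation.images.image", "observation.images.image2"]
policy_image_keys = real_keys + empty_camera_keys(1)

# The batch only carries data for the real cameras (strings stand in for tensors).
batch = {key: "frame" for key in real_keys}

# Keys without data are the ones a policy would fill with masked dummy inputs.
keys_to_mask = [key for key in policy_image_keys if key not in batch]
print(keys_to_mask)
```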
-### Example
-
-XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
-
-1. Set `--policy.empty_cameras=1`.
-2. The config adds a third key: `observation.images.empty_camera_0`.
-3. Use the rename map for your two real cameras as usual.
-4. The third slot is masked out — no fake images needed in your dataset.
-
-## Quick reference
-
-| Goal | What to do |
-| ----------------------------------------- | --------------------------------------------------------------------------- |
-| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
-| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
-| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
-| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
-| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
         self.vqvae_model.optimized_steps += 1
         # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
         if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
-            self.vqvae_model.discretized.fill_(True)
-            self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
+            self.vqvae_model.discretized = torch.tensor(True)
+            self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
             print("Finished discretizing action data!")
             self.vqvae_model.eval()
             for param in self.vqvae_model.vq_layer.parameters():
@@ -131,15 +131,6 @@ class _NormalizationMixin:
         if self.dtype is None:
             self.dtype = torch.float32
         self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
-        self._reshape_visual_stats()
-
-    def _reshape_visual_stats(self) -> None:
-        """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
-        for key, feature in self.features.items():
-            if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
-                for stat_name, stat_tensor in self._tensor_stats[key].items():
-                    if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
-                        self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
 
     def to(
         self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -158,7 +149,6 @@ class _NormalizationMixin:
         if dtype is not None:
             self.dtype = dtype
         self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
-        self._reshape_visual_stats()
         return self
 
     def state_dict(self) -> dict[str, Tensor]:
@@ -208,7 +198,6 @@ class _NormalizationMixin:
             # Don't load from state_dict, keep the explicitly provided stats
             # But ensure _tensor_stats is properly initialized
             self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)  # type: ignore[assignment]
-            self._reshape_visual_stats()
             return
 
         # Normal behavior: load stats from state_dict
@@ -220,8 +209,6 @@ class _NormalizationMixin:
                 dtype=torch.float32, device=self.device
             )
 
-        self._reshape_visual_stats()
-
         # Reconstruct the original stats dict from tensor stats for compatibility with to() method
         # and other functions that rely on self.stats
         self.stats = {}
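The `_reshape_visual_stats` helper in the hunks above turns per-channel statistics of shape `[C]` into `[C, 1, 1]` so they broadcast over the height and width of a `[C, H, W]` image. As a rough illustration of the arithmetic that broadcasting performs, here is the same per-channel normalization written out by hand with nested lists. This is a sketch, not the library code: `normalize_per_channel` and the sample values are made up for the example, and plain Python stands in for tensors.

```python
def normalize_per_channel(image, mean, std):
    """Apply (pixel - mean[c]) / std[c] to every pixel of channel c.

    `image` is a nested list shaped [C][H][W]; `mean` and `std` have length C.
    This is what a [C, 1, 1] statistic does when broadcast against [C, H, W].
    """
    return [
        [[(pixel - mean[c]) / std[c] for pixel in row] for row in channel]
        for c, channel in enumerate(image)
    ]


# A tiny 3-channel, 2x2 "image" with one mean and std per channel.
image = [
    [[1.0, 3.0], [5.0, 7.0]],
    [[2.0, 2.0], [2.0, 2.0]],
    [[0.0, 4.0], [8.0, 12.0]],
]
mean = [4.0, 2.0, 6.0]
std = [2.0, 1.0, 4.0]

normalized = normalize_per_channel(image, mean, std)
print(normalized[0])
```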
@@ -62,7 +62,6 @@ from lerobot.configs import parser
 from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.policies.factory import make_policy
 from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
 from lerobot.rl.process import ProcessSignalHandler
 from lerobot.rl.queue import get_last_item_from_queue
 from lerobot.robots import so_follower  # noqa: F401
@@ -259,11 +258,6 @@ def act_with_policy(
     policy = policy.eval()
     assert isinstance(policy, nn.Module)
 
-    preprocessor, postprocessor = make_sac_pre_post_processors(
-        config=cfg.policy,
-        dataset_stats=cfg.policy.dataset_stats,
-    )
-
     obs, info = online_env.reset()
     env_processor.reset()
     action_processor.reset()
@@ -295,9 +289,7 @@ def act_with_policy(
             # Time policy inference and check if it meets FPS requirement
             with policy_timer:
                 # Extract observation from transition for policy
-                normalized_observation = preprocessor.process_observation(observation)
-                action = policy.select_action(batch=normalized_observation)
-                # action = postprocessor.process_action(action)
+                action = policy.select_action(batch=observation)
             policy_fps = policy_timer.fps_last
 
             log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -66,7 +66,6 @@ from lerobot.datasets.factory import make_dataset
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.policies.factory import make_policy
 from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
 from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
 from lerobot.rl.process import ProcessSignalHandler
 from lerobot.rl.wandb_utils import WandBLogger
@@ -314,11 +313,6 @@ def add_actor_information_and_train(
 
     assert isinstance(policy, nn.Module)
 
-    preprocessor, _ = make_sac_pre_post_processors(
-        config=cfg.policy,
-        dataset_stats=cfg.policy.dataset_stats,
-    )
-
     policy.train()
 
     push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -414,9 +408,6 @@ def add_actor_information_and_train(
             done = batch["done"]
             check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
-            observations = preprocessor.process_observation(observations)
-            next_observations = preprocessor.process_observation(next_observations)
-
             observation_features, next_observation_features = get_observation_features(
                 policy=policy, observations=observations, next_observations=next_observations
             )
@@ -476,9 +467,6 @@ def add_actor_information_and_train(
 
         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
 
-        observations = preprocessor.process_observation(observations)
-        next_observations = preprocessor.process_observation(next_observations)
-
         observation_features, next_observation_features = get_observation_features(
             policy=policy, observations=observations, next_observations=next_observations
         )
@@ -33,40 +33,21 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
 logger = logging.getLogger(__name__)
 
 # Action feature keys
-ACTION_LINEAR_VEL = "linear_velocity"
-ACTION_ANGULAR_VEL = "angular_velocity"
+ACTION_LINEAR_VEL = "linear.vel"
+ACTION_ANGULAR_VEL = "angular.vel"
 
-# Observation feature keys — cameras
+# Observation feature keys
 OBS_FRONT = "front"
 OBS_REAR = "rear"
-
-# Observation feature keys — telemetry
-OBS_SPEED = "speed"
-OBS_BATTERY_LEVEL = "battery_level"
-OBS_ORIENTATION = "orientation"
-OBS_GPS_LATITUDE = "gps_latitude"
-OBS_GPS_LONGITUDE = "gps_longitude"
-OBS_GPS_SIGNAL = "gps_signal"
-OBS_SIGNAL_LEVEL = "signal_level"
+OBS_LINEAR_VEL = "linear.vel"
+OBS_BATTERY_LEVEL = "battery.level"
+OBS_ORIENTATION_DEG = "orientation.deg"
+OBS_GPS_LATITUDE = "gps.latitude"
+OBS_GPS_LONGITUDE = "gps.longitude"
+OBS_GPS_SIGNAL = "gps.signal"
+OBS_SIGNAL_LEVEL = "signal.level"
 OBS_VIBRATION = "vibration"
-OBS_LAMP = "lamp"
-
-# Observation feature keys — IMU sensors
-OBS_ACCELEROMETER_X = "accelerometer_x"
-OBS_ACCELEROMETER_Y = "accelerometer_y"
-OBS_ACCELEROMETER_Z = "accelerometer_z"
-OBS_GYROSCOPE_X = "gyroscope_x"
-OBS_GYROSCOPE_Y = "gyroscope_y"
-OBS_GYROSCOPE_Z = "gyroscope_z"
-OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
-OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
-OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
-
-# Observation feature keys — wheel RPMs
-OBS_WHEEL_RPM_0 = "wheel_rpm_0"
-OBS_WHEEL_RPM_1 = "wheel_rpm_1"
-OBS_WHEEL_RPM_2 = "wheel_rpm_2"
-OBS_WHEEL_RPM_3 = "wheel_rpm_3"
+OBS_LAMP_STATE = "lamp.state"
 
 
 class EarthRoverMiniPlus(Robot):
@@ -173,60 +154,33 @@ class EarthRoverMiniPlus(Robot):
             dict: Observation features with types/shapes:
                 - front: (480, 640, 3) - Front camera RGB image
                 - rear: (480, 640, 3) - Rear camera RGB image
-                - speed: float - Current speed (raw SDK value)
-                - battery_level: float - Battery level (0-100)
-                - orientation: float - Robot orientation in degrees
-                - gps_latitude: float - GPS latitude coordinate
-                - gps_longitude: float - GPS longitude coordinate
-                - gps_signal: float - GPS signal strength (percentage)
-                - signal_level: float - Network signal level (0-5)
+                - linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
+                - battery.level: float - Battery level (0-1, normalized from 0-100)
+                - orientation.deg: float - Robot orientation (0-1, normalized from raw value)
+                - gps.latitude: float - GPS latitude coordinate
+                - gps.longitude: float - GPS longitude coordinate
+                - gps.signal: float - GPS signal strength (0-1, normalized from percentage)
+                - signal.level: float - Network signal level (0-1, normalized from 0-5)
                 - vibration: float - Vibration sensor reading
-                - lamp: float - Lamp state (0=off, 1=on)
-                - accelerometer_x: float - Accelerometer X axis (raw SDK value)
-                - accelerometer_y: float - Accelerometer Y axis (raw SDK value)
-                - accelerometer_z: float - Accelerometer Z axis (raw SDK value)
-                - gyroscope_x: float - Gyroscope X axis (raw SDK value)
-                - gyroscope_y: float - Gyroscope Y axis (raw SDK value)
-                - gyroscope_z: float - Gyroscope Z axis (raw SDK value)
-                - magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
-                - magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
-                - magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
-                - wheel_rpm_0: float - Wheel 0 RPM
-                - wheel_rpm_1: float - Wheel 1 RPM
-                - wheel_rpm_2: float - Wheel 2 RPM
-                - wheel_rpm_3: float - Wheel 3 RPM
+                - lamp.state: float - Lamp state (0=off, 1=on)
         """
         return {
             # Cameras (height, width, channels)
             OBS_FRONT: (480, 640, 3),
             OBS_REAR: (480, 640, 3),
-            # Telemetry
-            OBS_SPEED: float,
+            # Motion state
+            OBS_LINEAR_VEL: float,
+            # Robot state
             OBS_BATTERY_LEVEL: float,
-            OBS_ORIENTATION: float,
+            OBS_ORIENTATION_DEG: float,
+            # GPS
             OBS_GPS_LATITUDE: float,
             OBS_GPS_LONGITUDE: float,
             OBS_GPS_SIGNAL: float,
+            # Sensors
             OBS_SIGNAL_LEVEL: float,
             OBS_VIBRATION: float,
-            OBS_LAMP: float,
-            # IMU — accelerometer
-            OBS_ACCELEROMETER_X: float,
-            OBS_ACCELEROMETER_Y: float,
-            OBS_ACCELEROMETER_Z: float,
-            # IMU — gyroscope
-            OBS_GYROSCOPE_X: float,
-            OBS_GYROSCOPE_Y: float,
-            OBS_GYROSCOPE_Z: float,
-            # IMU — magnetometer
-            OBS_MAGNETOMETER_X: float,
-            OBS_MAGNETOMETER_Y: float,
-            OBS_MAGNETOMETER_Z: float,
-            # Wheel RPMs
-            OBS_WHEEL_RPM_0: float,
-            OBS_WHEEL_RPM_1: float,
-            OBS_WHEEL_RPM_2: float,
-            OBS_WHEEL_RPM_3: float,
+            OBS_LAMP_STATE: float,
         }
 
     @cached_property
@@ -235,8 +189,8 @@ class EarthRoverMiniPlus(Robot):
 
         Returns:
             dict: Action features with types:
-                - linear_velocity: float - Target linear velocity (-1 to 1)
-                - angular_velocity: float - Target angular velocity (-1 to 1)
+                - linear.vel: float - Target linear velocity
+                - angular.vel: float - Target angular velocity
         """
         return {
             ACTION_LINEAR_VEL: float,
@@ -247,29 +201,19 @@ class EarthRoverMiniPlus(Robot):
     def get_observation(self) -> RobotObservation:
         """Get current robot observation from SDK.
 
-        Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
-        Frames are decoded from base64 and converted from BGR to RGB format.
-        Robot telemetry is retrieved from /data endpoint.
-        Sensor arrays (accels, gyros, mags, rpms) each contain entries of
-        [values..., timestamp]; the latest reading from each array is used.
-
         Returns:
             RobotObservation: Observation containing:
                 - front: Front camera image (480, 640, 3) in RGB format
                 - rear: Rear camera image (480, 640, 3) in RGB format
-                - speed: float - Current speed (raw SDK value)
-                - battery_level: float - Battery level (0-100)
-                - orientation: float - Robot orientation in degrees
-                - gps_latitude: float - GPS latitude coordinate
-                - gps_longitude: float - GPS longitude coordinate
-                - gps_signal: float - GPS signal strength (percentage)
-                - signal_level: float - Network signal level (0-5)
-                - vibration: float - Vibration sensor reading
-                - lamp: float - Lamp state (0=off, 1=on)
-                - accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
-                - gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
-                - magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
-                - wheel_rpm_0/1/2/3: float - Wheel RPMs
+                - linear.vel: Current speed (0-1, SDK reports only positive speeds)
+                - battery.level: Battery level (0-1, normalized from 0-100)
+                - orientation.deg: Robot orientation (0-1, normalized from raw value)
+                - gps.latitude: GPS latitude coordinate
+                - gps.longitude: GPS longitude coordinate
+                - gps.signal: GPS signal strength (0-1, normalized from percentage)
+                - signal.level: Network signal level (0-1, normalized from 0-5)
+                - vibration: Vibration sensor reading
+                - lamp.state: Lamp state (0=off, 1=on)
 
         Raises:
             DeviceNotConnectedError: If robot is not connected
@@ -291,41 +235,22 @@ class EarthRoverMiniPlus(Robot):
         # Get robot state from SDK
         robot_data = self._get_robot_data()
 
-        # Telemetry
-        observation[OBS_SPEED] = float(robot_data["speed"])
-        observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
-        observation[OBS_ORIENTATION] = float(robot_data["orientation"])
-        observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
-        observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
-        observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
-        observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
-        observation[OBS_VIBRATION] = float(robot_data["vibration"])
-        observation[OBS_LAMP] = float(robot_data["lamp"])
+        # Motion state
+        observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0  # Normalize 0-100 to 0-1
 
-        # Accelerometer — latest reading from accels array [x, y, z, ts]
-        accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
-        observation[OBS_ACCELEROMETER_X] = accel[0]
-        observation[OBS_ACCELEROMETER_Y] = accel[1]
-        observation[OBS_ACCELEROMETER_Z] = accel[2]
+        # Robot state
+        observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0  # Normalize 0-100 to 0-1
+        observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0  # Normalize to 0-1
 
-        # Gyroscope — latest reading from gyros array [x, y, z, ts]
-        gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
-        observation[OBS_GYROSCOPE_X] = gyro[0]
-        observation[OBS_GYROSCOPE_Y] = gyro[1]
-        observation[OBS_GYROSCOPE_Z] = gyro[2]
+        # GPS data
+        observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
+        observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
+        observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0  # Normalize percentage to 0-1
 
-        # Magnetometer — latest reading from mags array [x, y, z, ts]
-        mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
-        observation[OBS_MAGNETOMETER_X] = mag[0]
-        observation[OBS_MAGNETOMETER_Y] = mag[1]
-        observation[OBS_MAGNETOMETER_Z] = mag[2]
-
-        # Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
-        rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
-        observation[OBS_WHEEL_RPM_0] = rpm[0]
-        observation[OBS_WHEEL_RPM_1] = rpm[1]
-        observation[OBS_WHEEL_RPM_2] = rpm[2]
-        observation[OBS_WHEEL_RPM_3] = rpm[3]
+        # Sensors
+        observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0  # Normalize 0-5 to 0-1
+        observation[OBS_VIBRATION] = robot_data["vibration"]
+        observation[OBS_LAMP_STATE] = float(robot_data["lamp"])  # 0 or 1
 
         return observation
 
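One of the two `get_observation` variants above scales each raw SDK value by its documented range (speed and battery by 100, orientation by 360, GPS signal by 100, network signal by 5). A standalone sketch of that scaling, using the key names and divisors from the hunk, might look as follows. `normalize_telemetry` and the sample payload are hypothetical and only cover the telemetry fields, not the camera frames.

```python
def normalize_telemetry(robot_data: dict) -> dict:
    """Scale raw telemetry into the 0-1 ranges described in the docstring above."""
    return {
        "linear.vel": robot_data["speed"] / 100.0,  # 0-100 -> 0-1
        "battery.level": robot_data["battery"] / 100.0,  # 0-100 -> 0-1
        "orientation.deg": robot_data["orientation"] / 360.0,  # degrees -> 0-1
        "gps.latitude": robot_data["latitude"],  # passed through unchanged
        "gps.longitude": robot_data["longitude"],  # passed through unchanged
        "gps.signal": robot_data["gps_signal"] / 100.0,  # percentage -> 0-1
        "signal.level": robot_data["signal_level"] / 5.0,  # 0-5 -> 0-1
        "vibration": robot_data["vibration"],  # passed through unchanged
        "lamp.state": float(robot_data["lamp"]),  # 0 or 1
    }


# Made-up payload in the shape of the default dict returned by `_get_robot_data`.
sample = {
    "speed": 50,
    "battery": 80,
    "orientation": 180,
    "latitude": 48.85,
    "longitude": 2.35,
    "gps_signal": 100,
    "signal_level": 5,
    "vibration": 0.0,
    "lamp": 1,
}

observation = normalize_telemetry(sample)
print(observation["linear.vel"], observation["orientation.deg"])
```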
@@ -335,12 +260,11 @@ class EarthRoverMiniPlus(Robot):
 
         Args:
             action: Action dict with keys:
-                - linear_velocity: Target linear velocity (-1 to 1)
-                - angular_velocity: Target angular velocity (-1 to 1)
+                - linear.vel: Target linear velocity (-1 to 1)
+                - angular.vel: Target angular velocity (-1 to 1)
 
         Returns:
             RobotAction: The action that was sent (matches action_features keys)
-
         Raises:
             DeviceNotConnectedError: If robot is not connected
 
@@ -348,14 +272,18 @@ class EarthRoverMiniPlus(Robot):
             Actions are sent to SDK via POST /control endpoint.
             SDK expects commands in range [-1, 1].
         """
+
+        # Extract action values and convert to float
         linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
         angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
 
+        # Send command to SDK
         try:
             self._send_command_to_sdk(linear, angular)
         except Exception as e:
             logger.error(f"Error sending action: {e}")
 
+        # Return action in format matching action_features
         return {
             ACTION_LINEAR_VEL: linear,
             ACTION_ANGULAR_VEL: angular,
@@ -466,27 +394,11 @@ class EarthRoverMiniPlus(Robot):
             logger.error(f"Error decoding image: {e}")
             return None
 
-    @staticmethod
-    def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
-        """Extract the latest sensor reading from an SDK sensor array.
-
-        The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
-        ``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
-        This helper returns the *n_values* leading floats from the last entry,
-        falling back to zeros when the key is missing or the array is empty.
-        """
-        readings = robot_data.get(key)
-        if readings and len(readings) > 0:
-            latest = readings[-1]
-            return [float(v) for v in latest[:n_values]]
-        return [0.0] * n_values
-
     def _get_robot_data(self) -> dict:
         """Get robot telemetry data from SDK.
 
         Returns:
-            dict: Robot telemetry data including battery, speed, orientation, GPS,
-                and sensor arrays (accels, gyros, mags, rpms):
+            dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
                 - Current data (if request succeeds)
                 - Cached data (if request fails but cache exists)
                 - Default values (if request fails and no cache exists yet)
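To make the behaviour of `_latest_sensor_reading` in the hunk above concrete, here is the same logic as a free function, run on a made-up payload. The entry layout `[value_0, ..., value_n, timestamp]` comes from the docstring; the sample numbers and the function name without the leading underscore are assumptions for illustration.

```python
def latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
    """Return the first `n_values` floats of the newest entry, or zeros if unavailable."""
    readings = robot_data.get(key)
    if readings:  # covers both a missing key (None) and an empty list
        latest = readings[-1]
        return [float(v) for v in latest[:n_values]]
    return [0.0] * n_values


# Each entry is [x, y, z, timestamp]; the last entry is the most recent one.
robot_data = {
    "accels": [[0.1, 0.2, 9.8, 1000.0], [0.0, 0.1, 9.7, 1001.0]],
    "rpms": [],
}

accel = latest_sensor_reading(robot_data, "accels", n_values=3)
rpm = latest_sensor_reading(robot_data, "rpms", n_values=4)
gyro = latest_sensor_reading(robot_data, "gyros", n_values=3)
print(accel, rpm, gyro)
```

Slicing with `latest[:n_values]` is what drops the trailing timestamp, so the caller never has to know the entry length.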
@@ -508,23 +420,19 @@ class EarthRoverMiniPlus(Robot):
         # Fallback: use cache or default values
         if self._last_robot_data is not None:
             return self._last_robot_data
-
-        # Return dict with default values (used only on first failure before any cache exists)
-        return {
-            "speed": 0,
-            "battery": 0,
-            "orientation": 0,
-            "latitude": 0.0,
-            "longitude": 0.0,
-            "gps_signal": 0,
-            "signal_level": 0,
-            "vibration": 0.0,
-            "lamp": 0,
-            "accels": [],
-            "gyros": [],
-            "mags": [],
-            "rpms": [],
-        }
+        else:
+            # Return dict with default values (used only on first failure before any cache exists)
+            return {
+                "speed": 0,
+                "battery": 0,
+                "orientation": 0,
+                "latitude": 0.0,
+                "longitude": 0.0,
+                "gps_signal": 0,
+                "signal_level": 0,
+                "vibration": 0.0,
+                "lamp": 0,
+            }
 
     def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
         """Send control command to SDK.
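The docstring of `_get_robot_data` lists three outcomes: current data when the request succeeds, cached data when it fails but a cache exists, and default values otherwise. A minimal sketch of that fallback order is shown below. The `TelemetrySource` class, the `fetch` callable, and the broad `except Exception` are assumptions made for the example, since the request code itself is not part of the hunks shown.

```python
DEFAULT_TELEMETRY = {
    "speed": 0,
    "battery": 0,
    "orientation": 0,
    "latitude": 0.0,
    "longitude": 0.0,
    "gps_signal": 0,
    "signal_level": 0,
    "vibration": 0.0,
    "lamp": 0,
}


class TelemetrySource:
    """Hypothetical wrapper: fresh data first, then the cache, then defaults."""

    def __init__(self, fetch):
        self._fetch = fetch  # callable returning a dict, or raising on failure
        self._last = None  # most recent successful response

    def read(self) -> dict:
        try:
            data = self._fetch()
        except Exception:
            data = None
        if data is not None:
            self._last = data
            return data
        if self._last is not None:
            return self._last
        return dict(DEFAULT_TELEMETRY)


# Simulate: first request fails, second succeeds, third fails again.
responses = iter([ConnectionError("offline"), {"speed": 40, "battery": 90}, ConnectionError("offline")])


def flaky_fetch():
    item = next(responses)
    if isinstance(item, Exception):
        raise item
    return item


source = TelemetrySource(flaky_fetch)
first, second, third = source.read(), source.read(), source.read()
print(first["speed"], second["speed"], third["speed"])
```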
@@ -23,46 +23,65 @@ class InputController:
     """Base class for input controllers that generate motion deltas."""
 
     def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
+        """
+        Initialize the controller.
+
+        Args:
+            x_step_size: Base movement step size in meters
+            y_step_size: Base movement step size in meters
+            z_step_size: Base movement step size in meters
+        """
         self.x_step_size = x_step_size
         self.y_step_size = y_step_size
         self.z_step_size = z_step_size
         self.running = True
-        self.episode_end_status = None
+        self.episode_end_status = None  # None, "success", or "failure"
         self.intervention_flag = False
         self.open_gripper_command = False
         self.close_gripper_command = False
 
     def start(self):
+        """Start the controller and initialize resources."""
         pass
 
     def stop(self):
-        pass
-
-    def reset(self):
+        """Stop the controller and release resources."""
         pass
 
     def get_deltas(self):
+        """Get the current movement deltas (dx, dy, dz) in meters."""
         return 0.0, 0.0, 0.0
 
     def update(self):
+        """Update controller state - call this once per frame."""
         pass
 
     def __enter__(self):
+        """Support for use in 'with' statements."""
         self.start()
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        """Ensure resources are released when exiting 'with' block."""
         self.stop()
 
     def get_episode_end_status(self):
+        """
+        Get the current episode end status.
+
+        Returns:
+            None if episode should continue, "success" or "failure" otherwise
+        """
         status = self.episode_end_status
-        self.episode_end_status = None
+        self.episode_end_status = None  # Reset after reading
         return status
 
     def should_intervene(self):
+        """Return True if intervention flag was set."""
         return self.intervention_flag
 
     def gripper_command(self):
+        """Return the current gripper command."""
         if self.open_gripper_command == self.close_gripper_command:
             return "stay"
         elif self.open_gripper_command:
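Two patterns in `InputController` above are easy to miss in diff form: `__enter__`/`__exit__` make the controller usable in a `with` block, and `get_episode_end_status` resets the status after it is read, so each event is reported once. The stripped-down class below is a hypothetical stand-in (it is not the library class, and unlike the original it starts with `running = False` so the effect of `start()` is visible) that shows both behaviours.

```python
class MinimalController:
    """Hypothetical stand-in mirroring the with-statement and read-and-reset behaviour."""

    def __init__(self):
        self.running = False
        self.episode_end_status = None  # None, "success", or "failure"

    def start(self):
        self.running = True

    def stop(self):
        self.running = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def get_episode_end_status(self):
        status = self.episode_end_status
        self.episode_end_status = None  # reset so the event is reported only once
        return status


with MinimalController() as controller:
    running_inside = controller.running
    controller.episode_end_status = "success"  # e.g. set by a key handler
    first_read = controller.get_episode_end_status()
    second_read = controller.get_episode_end_status()

running_after = controller.running
print(first_read, second_read, running_inside, running_after)
```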
@@ -83,14 +102,14 @@ class KeyboardController(InputController):
             "backward_y": False,
             "forward_z": False,
             "backward_z": False,
+            "quit": False,
             "success": False,
             "failure": False,
-            "intervention": False,
-            "rerecord": False,
         }
         self.listener = None
 
     def start(self):
+        """Start the keyboard listener."""
         from pynput import keyboard
 
         def on_press(key):
@@ -107,21 +126,16 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = True
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.esc:
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif key == keyboard.Key.space:
|
||||
self.key_states["intervention"] = not self.key_states["intervention"]
|
||||
elif hasattr(key, "char") and key.char == "r":
|
||||
self.key_states["rerecord"] = True
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -139,10 +153,10 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = False
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -151,18 +165,18 @@ class KeyboardController(InputController):
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift / Shift_R: Move in Z axis")
|
||||
print(" Ctrl_R / Ctrl_L: Open / Close gripper")
|
||||
print(" Space: Toggle intervention")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Esc: End episode with FAILURE")
|
||||
print(" R: Rerecord episode")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
@@ -180,58 +194,18 @@ class KeyboardController(InputController):
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_intervene(self):
|
||||
return self.key_states["intervention"]
|
||||
|
||||
def reset(self):
|
||||
for key in self.key_states:
|
||||
self.key_states[key] = False
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input using pygame.
|
||||
|
||||
Matches gym-hil button/axis conventions for Linux gamepads, including
|
||||
Xbox mappings.
|
||||
"""
|
||||
|
||||
# Face buttons (same across most controllers on Linux)
|
||||
BUTTON_A = 0
|
||||
BUTTON_B = 1
|
||||
BUTTON_X = 2
|
||||
BUTTON_Y = 3
|
||||
BUTTON_LB = 4
|
||||
BUTTON_RB = 5
|
||||
# Stick axes
|
||||
AXIS_LEFT_X = 0
|
||||
AXIS_LEFT_Y = 1
|
||||
AXIS_RIGHT_X = 2
|
||||
AXIS_RIGHT_Y = 3
|
||||
|
||||
# Default trigger buttons
|
||||
BUTTON_LT = 6
|
||||
BUTTON_RT = 7
|
||||
|
||||
# Xbox (gym-hil mapping on Linux)
|
||||
XBOX_BUTTON_LT = 9
|
||||
XBOX_BUTTON_RT = 10
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
self.is_xbox = False
|
||||
self._xbox360_profile = False
|
||||
self._invert_left_x = False
|
||||
self._invert_left_y = True
|
||||
self._invert_right_y = True
|
||||
|
||||
def _detect_xbox(self, name):
|
||||
name_lower = name.lower()
|
||||
return any(tag in name_lower for tag in ["xbox", "microsoft", "x-box"])
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
@@ -244,35 +218,18 @@ class GamepadController(InputController):
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
joystick_name = self.joystick.get_name()
|
||||
self.is_xbox = self._detect_xbox(joystick_name)
|
||||
self._xbox360_profile = joystick_name == "Xbox 360 Controller"
|
||||
if self._xbox360_profile:
|
||||
# gym-hil "Xbox 360 Controller" profile
|
||||
self.AXIS_RIGHT_X = 3
|
||||
self.AXIS_RIGHT_Y = 4
|
||||
self.BUTTON_LT = self.XBOX_BUTTON_LT
|
||||
self.BUTTON_RT = self.XBOX_BUTTON_RT
|
||||
self._invert_left_x = True
|
||||
else:
|
||||
# gym-hil default profile
|
||||
self.AXIS_RIGHT_X = 2
|
||||
self.AXIS_RIGHT_Y = 3
|
||||
self.BUTTON_LT = 6
|
||||
self.BUTTON_RT = 7
|
||||
self._invert_left_x = False
|
||||
logging.info(f"Initialized gamepad: {joystick_name} (xbox={self.is_xbox})")
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" RB: Intervention toggle")
|
||||
print(" LT / RT: Close / Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" A: End episode with FAILURE")
|
||||
print(" X: Rerecord episode")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
@@ -282,56 +239,67 @@ class GamepadController(InputController):
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == self.BUTTON_Y:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif event.button == self.BUTTON_A:
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif event.button == self.BUTTON_X:
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
elif event.button == self.BUTTON_LT:
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
elif event.button == self.BUTTON_RT:
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [self.BUTTON_Y, self.BUTTON_A, self.BUTTON_X]:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
elif event.button == self.BUTTON_LT:
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
elif event.button == self.BUTTON_RT:
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
if self.joystick.get_button(self.BUTTON_RB):
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
x_input = self.joystick.get_axis(self.AXIS_LEFT_X)
|
||||
y_input = self.joystick.get_axis(self.AXIS_LEFT_Y)
|
||||
z_input = self.joystick.get_axis(self.AXIS_RIGHT_Y)
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
|
||||
x_input = self.joystick.get_axis(1) # Left/Right
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
if self._invert_left_x:
|
||||
x_input = -x_input
|
||||
if self._invert_left_y:
|
||||
y_input = -y_input
|
||||
if self._invert_right_y:
|
||||
z_input = -z_input
|
||||
|
||||
delta_x = y_input * self.y_step_size
|
||||
delta_y = x_input * self.x_step_size
|
||||
delta_z = z_input * self.z_step_size
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -x_input * self.x_step_size # Forward/backward
|
||||
delta_y = -y_input * self.y_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
@@ -341,15 +309,7 @@ class GamepadController(InputController):
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI.
|
||||
|
||||
Supports auto-detection of controller type for correct HID report parsing.
|
||||
Currently supported: Logitech RumblePad 2, 8BitDo Ultimate 2C Wireless.
|
||||
"""
|
||||
|
||||
CONTROLLER_LOGITECH = "logitech"
|
||||
CONTROLLER_8BITDO = "8bitdo"
|
||||
CONTROLLER_UNKNOWN = "unknown"
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -358,26 +318,36 @@ class GamepadControllerHID(InputController):
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
self.controller_type = self.CONTROLLER_UNKNOWN
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
device_name = device["product_string"]
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5", "8BitDo"]):
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
@@ -385,15 +355,8 @@ class GamepadControllerHID(InputController):
|
||||
)
|
||||
return None
|
||||
|
||||
def _detect_controller_type(self, product_string):
|
||||
product = product_string.lower() if product_string else ""
|
||||
if "8bitdo" in product:
|
||||
return self.CONTROLLER_8BITDO
|
||||
elif "logitech" in product:
|
||||
return self.CONTROLLER_LOGITECH
|
||||
return self.CONTROLLER_UNKNOWN
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
@@ -411,22 +374,12 @@ class GamepadControllerHID(InputController):
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
self.controller_type = self._detect_controller_type(product)
|
||||
logging.info(f"Detected controller type: {self.controller_type}")
|
||||
|
||||
print("Gamepad controls (HID mode):")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick: Move in Z axis (vertical)")
|
||||
print(" RB: Intervention toggle")
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
print(" L3 (left stick click): Close gripper")
|
||||
print(" R3 (right stick click): Open gripper")
|
||||
else:
|
||||
print(" LT: Close gripper")
|
||||
print(" RT: Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" X: End episode with FAILURE")
|
||||
print(" A: Rerecord episode")
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
@@ -434,124 +387,74 @@ class GamepadControllerHID(InputController):
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""Read the device several times to drain the HID buffer and get a stable reading."""
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with HIDAPI, we need to read the device several times to get a stable reading.
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
if not data:
|
||||
return
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_y = (data[1] - 128) / 128.0
|
||||
self.left_x = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
self._parse_8bitdo(data)
|
||||
else:
|
||||
self._parse_logitech(data)
|
||||
# Apply deadzone
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def _apply_deadzone(self):
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
def _parse_8bitdo(self, data):
|
||||
"""Parse HID report from 8BitDo Ultimate 2C Wireless (Bluetooth on macOS).
|
||||
|
||||
11-byte report layout:
|
||||
byte[0]: Report ID (0x01)
|
||||
byte[1]: D-pad hat switch (0=N, 2=E, 5=S, 6=W, 15=neutral)
|
||||
byte[2]: Left Stick X (0=left, 127=center, 255=right)
|
||||
byte[3]: Left Stick Y (0=up, 127=center, 255=down)
|
||||
byte[4]: Right Stick X (inverted: 255=left, 0=right)
|
||||
byte[5]: Right Stick Y (0=up, 127=center, 255=down)
|
||||
byte[6]: RT analog trigger (0-255)
|
||||
byte[7]: LT analog trigger (0-255)
|
||||
byte[8]: Buttons -- bit0=A, bit1=B, bit3=X, bit4=Y, bit6=LB, bit7=RB
|
||||
byte[9]: System -- bit0=LT(digital), bit1=RT(digital), bit3=Select,
|
||||
bit4=Start, bit5=L3, bit6=R3
|
||||
byte[10]: Unused
|
||||
"""
|
||||
if len(data) < 11:
|
||||
return
|
||||
|
||||
self.left_x = (data[2] - 127) / 128.0
|
||||
self.left_y = (data[3] - 127) / 128.0
|
||||
self.right_x = -(data[4] - 127) / 128.0
|
||||
self.right_y = (data[5] - 127) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[8]
|
||||
|
||||
# RB (bit 7) = intervention
|
||||
self.intervention_flag = bool(buttons & 0x80)
|
||||
|
||||
# Stick clicks for gripper: R3 (byte[9] bit6) = open, L3 (byte[9] bit5) = close
|
||||
system = data[9]
|
||||
self.open_gripper_command = bool(system & 0x40) # R3
|
||||
self.close_gripper_command = bool(system & 0x20) # L3
|
||||
|
||||
# Y (bit 4) = success, X (bit 3) = failure, A (bit 0) = rerecord
|
||||
if buttons & 0x10:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 0x08:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 0x01:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def _parse_logitech(self, data):
|
||||
"""Parse HID report from Logitech RumblePad 2 (and similar Logitech gamepads).
|
||||
|
||||
Report layout (8+ bytes):
|
||||
byte[1]: Left Stick X (0-255, center=128)
|
||||
byte[2]: Left Stick Y (0-255, center=128)
|
||||
byte[3]: Right Stick X (0-255, center=128)
|
||||
byte[4]: Right Stick Y (0-255, center=128)
|
||||
byte[5]: Face buttons bitmask
|
||||
byte[6]: Shoulder/trigger buttons bitmask
|
||||
"""
|
||||
if len(data) < 8:
|
||||
return
|
||||
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[5]
|
||||
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def get_deltas(self):
|
||||
delta_x = -self.left_y * self.x_step_size
|
||||
delta_y = -self.left_x * self.y_step_size
|
||||
delta_z = -self.right_y * self.z_step_size
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_x * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_y * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
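The HID parsers in the hunk above turn raw report bytes into stick values and episode events. A minimal standalone sketch of that idea follows. The helper names `normalize_axis` and `decode_episode_status` are hypothetical and not part of the diff, the string return values stand in for the `TeleopEvents` members, and the bit positions are taken from the 8BitDo layout described in the `_parse_8bitdo` docstring.

```python
from typing import Optional


def normalize_axis(raw: int, center: int = 128, deadzone: float = 0.1) -> float:
    """Map a raw 0-255 axis byte to roughly [-1.0, 1.0], zeroing values inside the deadzone.

    Hypothetical helper: the diff uses center=128 for the Logitech layout and
    center=127 for the 8BitDo layout.
    """
    value = (raw - center) / 128.0
    return 0.0 if abs(value) < deadzone else value


def decode_episode_status(buttons: int) -> Optional[str]:
    """Decode the face-button byte using the 8BitDo bit layout from the docstring.

    Priority mirrors the if/elif chain in the diff: Y (bit 4) wins over
    X (bit 3), which wins over A (bit 0). Strings are placeholders for
    the TeleopEvents values.
    """
    if buttons & 0x10:  # Y -> success
        return "success"
    if buttons & 0x08:  # X -> failure
        return "failure"
    if buttons & 0x01:  # A -> rerecord
        return "rerecord"
    return None


if __name__ == "__main__":
    print(normalize_axis(0))             # stick fully deflected in the negative direction
    print(normalize_axis(135))           # small offset, suppressed by the deadzone
    print(decode_episode_status(0x18))   # Y and X pressed together: Y takes priority
```

Keeping normalization and button decoding as pure functions makes the byte layout easy to check without a connected controller.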
@@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
def action_features(self) -> dict:
|
||||
"""Return action format for rover (linear and angular velocities)."""
|
||||
return {
|
||||
"linear_velocity": float,
|
||||
"angular_velocity": float,
|
||||
"linear.vel": float,
|
||||
"angular.vel": float,
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
Get the current action based on pressed keys.
|
||||
|
||||
Returns:
|
||||
RobotAction with 'linear_velocity' and 'angular_velocity' keys.
|
||||
RobotAction with 'linear.vel' and 'angular.vel' keys
|
||||
"""
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
@@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
return {
|
||||
"linear_velocity": linear_velocity,
|
||||
"angular_velocity": angular_velocity,
|
||||
"linear.vel": linear_velocity,
|
||||
"angular.vel": angular_velocity,
|
||||
}
|
||||
|
||||
@@ -42,8 +42,6 @@ from lerobot.policies.factory import (
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -462,45 +460,3 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_vqbet_discretize_keeps_buffers_on_device():
|
||||
"""Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.
|
||||
|
||||
Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
|
||||
registered buffer with a new CPU tensor, causing DDP to crash with:
|
||||
RuntimeError: No backend type associated with device type cpu
|
||||
The fix uses `.fill_(True)` to update in-place, preserving device placement.
|
||||
"""
|
||||
config = VQBeTConfig()
|
||||
config.input_features = {
|
||||
OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
# Tiny sizes for fast CPU/GPU execution.
|
||||
config.n_vqvae_training_steps = 3
|
||||
config.vqvae_n_embed = 8
|
||||
config.vqvae_embedding_dim = 32
|
||||
config.vqvae_enc_hidden_dim = 32
|
||||
config.action_chunk_size = 2
|
||||
config.crop_shape = (84, 84)
|
||||
|
||||
head = VQBeTHead(config).to(DEVICE)
|
||||
vqvae = head.vqvae_model
|
||||
|
||||
dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)
|
||||
n_steps = config.n_vqvae_training_steps
|
||||
for _ in range(n_steps):
|
||||
head.discretize(n_steps, dummy_actions)
|
||||
|
||||
assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
|
||||
"vqvae_model.discretized was moved off the model device after discretize(). "
|
||||
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
|
||||
)
|
||||
assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
|
||||
"vq_layer.freeze_codebook was moved off the model device after discretize(). "
|
||||
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
|
||||
)
|
||||
|
||||