Compare commits

...

3 Commits

Author SHA1 Message Date
Pepijn
727ca1a92c chore(rollout): remove speculative action order fix 2026-04-24 16:24:44 +02:00
Pepijn
ee737b72d0 fix(rollout): avoid oversampling observations during interpolation 2026-04-24 15:42:36 +02:00
Pepijn
cf2e42f557 fix(rollout): preserve relative action chunks 2026-04-24 15:33:16 +02:00
4 changed files with 212 additions and 15 deletions

View File

@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded
def _get_current_raw_state(
relative_step: RelativeActionsProcessorStep,
fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
"""Return the current raw state cached by the relative-action step.
``RelativeActionsProcessorStep`` caches the observation state before any
observation normalization. Re-anchoring RTC leftovers must use that raw
state rather than the normalized observation that the policy consumes.
"""
if relative_step._last_state is not None:
return relative_step._last_state
return fallback_state
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
state_tensor = _get_current_raw_state(
self._relative_step, obs_batch.get(OBS_STATE)
)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import logging
from collections import deque
from contextlib import nullcontext
from copy import copy
@@ -34,9 +35,9 @@ logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
``select_action``) on the given observation frame and returns a
CPU action tensor reordered to match the dataset action keys.
``get_action`` runs the full policy pipeline when its local action
queue is empty, postprocesses the whole predicted chunk immediately,
and then returns one already-postprocessed CPU action at a time.
"""
def __init__(
@@ -58,6 +59,8 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque()
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
@@ -78,9 +81,28 @@ class SyncInferenceEngine(InferenceEngine):
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
"""Convert a postprocessed action chunk into ordered per-step CPU tensors."""
if action_chunk.ndim == 2:
action_chunk = action_chunk.unsqueeze(0)
n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1])
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
for action in action_chunk.squeeze(0):
action_tensor = action.cpu()
action_dict = make_robot_action(action_tensor, self._dataset_features)
ordered_action = torch.tensor(
[action_dict[k] for k in self._ordered_action_keys], dtype=action_tensor.dtype
)
self._processed_action_queue.append(ordered_action)
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if self._processed_action_queue:
return self._processed_action_queue.popleft().clone()
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
@@ -97,11 +119,10 @@ class SyncInferenceEngine(InferenceEngine):
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
action = self._postprocessor(action)
action_tensor = action.squeeze(0).cpu()
action_chunk = self._policy.predict_action_chunk(observation)
processed_chunk = self._postprocessor(action_chunk)
# Reorder to match dataset action ordering so the caller can treat
# the returned tensor uniformly across backends.
action_dict = make_robot_action(action_tensor, self._dataset_features)
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
self._enqueue_processed_chunk(processed_chunk)
if not self._processed_action_queue:
return None
return self._processed_action_queue.popleft().clone()

View File

@@ -47,8 +47,12 @@ class BaseStrategy(RolloutStrategy):
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(cfg.fps)
observation_interval = 1.0 / cfg.fps
start_time = time.perf_counter()
next_observation_time = 0.0
obs = None
obs_processed = None
engine.resume()
logger.info("Base strategy control loop started")
@@ -59,13 +63,18 @@ class BaseStrategy(RolloutStrategy):
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if obs is None or loop_start >= next_observation_time:
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
next_observation_time = loop_start + observation_interval
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
if obs_processed is None:
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime)

View File

@@ -13,7 +13,9 @@ Flow under test:
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import numpy as np
import torch
from lerobot.configs.types import (
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
PolicyFeature,
RTCAttentionSchedule,
)
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor import (
PolicyProcessorPipeline,
TransitionKey,
batch_to_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor
def _ensure_rollout_test_packages() -> Path:
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
rollout_pkg.__path__ = [str(rollout_dir)]
inference_pkg = sys.modules.setdefault(
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
)
inference_pkg.__path__ = [str(rollout_dir / "inference")]
return rollout_dir
def _import_rollout_module(module_name: str, relative_path: str):
rollout_dir = _ensure_rollout_test_packages()
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod)
return mod
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
ACTION_DIM = 6
CHUNK_SIZE = 50
EXECUTION_HORIZON = 10
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
return relative_step, normalizer, unnormalizer, absolute_step
def _make_relative_sync_pipelines(
action_dim=ACTION_DIM,
action_names: list[str] | None = None,
exclude_joints: list[str] | None = None,
):
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=exclude_joints or [],
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
postprocessor = PolicyProcessorPipeline(
steps=[absolute_step],
name="test_postprocessor",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
return relative_step, preprocessor, postprocessor
class _ChunkPolicyStub:
def __init__(self, action_dim: int, n_action_steps: int):
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
self._chunk = torch.zeros(1, n_action_steps, action_dim)
self.predict_calls = 0
def reset(self):
return None
def predict_action_chunk(self, batch):
self.predict_calls += 1
return self._chunk.clone()
def select_action(self, batch):
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
@@ -120,6 +194,82 @@ class TestActionQueueRelativeActions:
torch.testing.assert_close(first_action, absolute_actions[0])
class TestRolloutInferenceRelativeActions:
"""Regression tests for rollout inference engines with relative-action policies."""
def test_sync_engine_postprocesses_chunk_before_queueing(self):
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
dataset_features = {ACTION: {"names": ordered_action_keys}}
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task="test",
device="cpu",
robot_type="mock",
)
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
state_2 = 10 * state_1
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
assert policy.predict_calls == 1
def test_rtc_reanchoring_prefers_raw_cached_state(self):
"""RTC re-anchoring must use the raw state cached before observation normalization."""
action_dim = ACTION_DIM
relative_step = RelativeActionsProcessorStep(
enabled=True,
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
)
features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
}
stats = {
OBS_STATE: {
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
"std": 2 * np.ones(action_dim, dtype=np.float32),
},
ACTION: {
"mean": np.zeros(action_dim, dtype=np.float32),
"std": np.ones(action_dim, dtype=np.float32),
},
}
preprocessor = PolicyProcessorPipeline(
steps=[
relative_step,
NormalizerProcessorStep(
features=features,
norm_map={
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
},
stats=stats,
),
],
name="test_preprocessor_with_state_norm",
)
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
torch.testing.assert_close(current_state, raw_state)
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""