Compare commits

...

3 Commits

Author SHA1 Message Date
Pepijn
727ca1a92c chore(rollout): remove speculative action order fix 2026-04-24 16:24:44 +02:00
Pepijn
ee737b72d0 fix(rollout): avoid oversampling observations during interpolation 2026-04-24 15:42:36 +02:00
Pepijn
cf2e42f557 fix(rollout): preserve relative action chunks 2026-04-24 15:33:16 +02:00
4 changed files with 212 additions and 15 deletions

View File

@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded return padded
def _get_current_raw_state(
    relative_step: RelativeActionsProcessorStep,
    fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
    """Select the state used to re-anchor RTC leftovers.

    ``RelativeActionsProcessorStep`` caches the observation state before any
    observation normalization, and that raw value is what re-anchoring needs;
    the normalized observation consumed by the policy is not suitable.

    Args:
        relative_step: Step whose cached ``_last_state`` is preferred.
        fallback_state: Value returned when nothing has been cached yet.

    Returns:
        The cached state if it is not ``None``, otherwise ``fallback_state``.
    """
    cached_state = relative_step._last_state
    return fallback_state if cached_state is None else cached_state
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# RTCInferenceEngine # RTCInferenceEngine
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
preprocessed = self._preprocessor(obs_batch) preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None: if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE) state_tensor = _get_current_raw_state(
self._relative_step, obs_batch.get(OBS_STATE)
)
if state_tensor is not None: if state_tensor is not None:
prev_abs = queue.get_processed_left_over() prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0: if prev_abs is not None and prev_abs.numel() > 0:

View File

@@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from copy import copy from copy import copy
@@ -34,9 +35,9 @@ logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine): class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call. """Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor + ``get_action`` runs the full policy pipeline when its local action
``select_action``) on the given observation frame and returns a queue is empty, postprocesses the whole predicted chunk immediately,
CPU action tensor reordered to match the dataset action keys. and then returns one already-postprocessed CPU action at a time.
""" """
def __init__( def __init__(
@@ -58,6 +59,8 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task self._task = task
self._device = torch.device(device or "cpu") self._device = torch.device(device or "cpu")
self._robot_type = robot_type self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque()
logger.info( logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)", "SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device, self._device,
@@ -78,9 +81,28 @@ class SyncInferenceEngine(InferenceEngine):
self._policy.reset() self._policy.reset()
self._preprocessor.reset() self._preprocessor.reset()
self._postprocessor.reset() self._postprocessor.reset()
self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
    """Split a postprocessed action chunk into ordered per-step CPU actions.

    Each kept step is converted to a robot action dict, reordered to
    ``self._ordered_action_keys`` and appended to
    ``self._processed_action_queue``.

    Args:
        action_chunk: Postprocessed policy output. A 3-D tensor is read as
            ``(batch, steps, action_dim)``, a 2-D tensor as
            ``(steps, action_dim)`` and a 1-D tensor as a single action.
            Only the first batch entry is queued.

    At most ``policy.config.n_action_steps`` leading steps are kept; when the
    config does not define that attribute (or it is ``None``) the whole chunk
    is queued.
    """
    # Promote lower-rank inputs to (batch, steps, action_dim).
    if action_chunk.ndim == 1:
        action_chunk = action_chunk.unsqueeze(0)
    if action_chunk.ndim == 2:
        action_chunk = action_chunk.unsqueeze(0)

    available_steps = action_chunk.shape[1]
    configured_steps = getattr(self._policy.config, "n_action_steps", None)
    if configured_steps is None:
        n_steps = available_steps
    else:
        n_steps = min(configured_steps, available_steps)

    # Index the batch explicitly: ``squeeze(0)`` is a no-op when batch > 1 and
    # would make the loop below walk over batch entries instead of time steps.
    # A single ``.cpu()`` moves the whole slice in one transfer rather than
    # one transfer per step.
    step_actions = action_chunk[0, :n_steps].cpu()

    for action_tensor in step_actions:
        action_dict = make_robot_action(action_tensor, self._dataset_features)
        ordered_action = torch.tensor(
            [action_dict[k] for k in self._ordered_action_keys],
            dtype=action_tensor.dtype,
        )
        self._processed_action_queue.append(ordered_action)
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" """Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if self._processed_action_queue:
return self._processed_action_queue.popleft().clone()
if obs_frame is None: if obs_frame is None:
return None return None
# Shallow copy is intentional: the caller (`send_next_action`) builds # Shallow copy is intentional: the caller (`send_next_action`) builds
@@ -97,11 +119,10 @@ class SyncInferenceEngine(InferenceEngine):
observation, self._device, self._task, self._robot_type observation, self._device, self._task, self._robot_type
) )
observation = self._preprocessor(observation) observation = self._preprocessor(observation)
action = self._policy.select_action(observation) action_chunk = self._policy.predict_action_chunk(observation)
action = self._postprocessor(action) processed_chunk = self._postprocessor(action_chunk)
action_tensor = action.squeeze(0).cpu()
# Reorder to match dataset action ordering so the caller can treat self._enqueue_processed_chunk(processed_chunk)
# the returned tensor uniformly across backends. if not self._processed_action_queue:
action_dict = make_robot_action(action_tensor, self._dataset_features) return None
return torch.tensor([action_dict[k] for k in self._ordered_action_keys]) return self._processed_action_queue.popleft().clone()

View File

@@ -47,8 +47,12 @@ class BaseStrategy(RolloutStrategy):
interpolator = self._interpolator interpolator = self._interpolator
control_interval = interpolator.get_control_interval(cfg.fps) control_interval = interpolator.get_control_interval(cfg.fps)
observation_interval = 1.0 / cfg.fps
start_time = time.perf_counter() start_time = time.perf_counter()
next_observation_time = 0.0
obs = None
obs_processed = None
engine.resume() engine.resume()
logger.info("Base strategy control loop started") logger.info("Base strategy control loop started")
@@ -59,13 +63,18 @@ class BaseStrategy(RolloutStrategy):
logger.info("Duration limit reached (%.0fs)", cfg.duration) logger.info("Duration limit reached (%.0fs)", cfg.duration)
break break
obs = robot.get_observation() if obs is None or loop_start >= next_observation_time:
obs_processed = ctx.processors.robot_observation_processor(obs) obs = robot.get_observation()
engine.notify_observation(obs_processed) obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
next_observation_time = loop_start + observation_interval
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue continue
if obs_processed is None:
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator) action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime) self._log_telemetry(obs_processed, action_dict, ctx.runtime)

View File

@@ -13,7 +13,9 @@ Flow under test:
import importlib.util import importlib.util
import sys import sys
from pathlib import Path from pathlib import Path
from types import ModuleType, SimpleNamespace
import numpy as np
import torch import torch
from lerobot.configs.types import ( from lerobot.configs.types import (
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
PolicyFeature, PolicyFeature,
RTCAttentionSchedule, RTCAttentionSchedule,
) )
from lerobot.processor import TransitionKey, batch_to_transition from lerobot.processor import (
PolicyProcessorPipeline,
TransitionKey,
batch_to_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import ( from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep, AbsoluteActionsProcessorStep,
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py") _rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor RTCProcessor = _rtc_mod.RTCProcessor
def _ensure_rollout_test_packages() -> Path:
    """Make ``lerobot.rollout`` and ``lerobot.rollout.inference`` resolvable.

    For each of the two package names, an entry is created in ``sys.modules``
    if one is not already present, and its ``__path__`` is pointed at the
    matching source directory (three levels above this file's directory, under
    ``src/lerobot/rollout``).

    Returns:
        The ``src/lerobot/rollout`` directory path.
    """
    base_dir = Path(__file__).resolve().parents[3]
    rollout_dir = base_dir / "src" / "lerobot" / "rollout"

    package_dirs = (
        ("lerobot.rollout", rollout_dir),
        ("lerobot.rollout.inference", rollout_dir / "inference"),
    )
    for package_name, package_dir in package_dirs:
        if package_name not in sys.modules:
            sys.modules[package_name] = ModuleType(package_name)
        # Always (re)point __path__, even for a pre-existing entry.
        sys.modules[package_name].__path__ = [str(package_dir)]

    return rollout_dir
def _import_rollout_module(module_name: str, relative_path: str):
    """Load a rollout source file and register it under ``module_name``.

    Args:
        module_name: Fully qualified name to register in ``sys.modules``.
        relative_path: Path of the source file relative to the rollout
            directory returned by ``_ensure_rollout_test_packages``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no loadable spec can be built for the file.
    """
    rollout_dir = _ensure_rollout_test_packages()
    module_path = rollout_dir / relative_path

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot build an import spec for {module_name!r} from {module_path}",
            name=module_name,
            path=str(module_path),
        )

    mod = importlib.util.module_from_spec(spec)

    # Register before executing so imports of ``module_name`` that happen
    # while the module body runs resolve to this object.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind for later imports.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return mod
# Load the rollout modules straight from their source files and register each
# one in ``sys.modules`` under its package-qualified name.
# NOTE(review): robot_wrapper and inference/base are presumably loaded first
# because sync/rtc import them - confirm before reordering.
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
# Short aliases for the objects under test.
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
ACTION_DIM = 6 ACTION_DIM = 6
CHUNK_SIZE = 50 CHUNK_SIZE = 50
EXECUTION_HORIZON = 10 EXECUTION_HORIZON = 10
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
return relative_step, normalizer, unnormalizer, absolute_step return relative_step, normalizer, unnormalizer, absolute_step
def _make_relative_sync_pipelines(
    action_dim=ACTION_DIM,
    action_names: list[str] | None = None,
    exclude_joints: list[str] | None = None,
):
    """Build a relative-action pre/post-processor pair for sync-engine tests.

    Args:
        action_dim: Number of joints used to generate default action names.
        action_names: Explicit action names; when falsy, ``joint_<i>.pos`` is
            generated for ``i`` in ``range(action_dim)``.
        exclude_joints: Joints passed to the relative step as excluded; a
            falsy value becomes an empty list.

    Returns:
        ``(relative_step, preprocessor, postprocessor)`` where the
        preprocessor holds only the relative step and the postprocessor holds
        only the absolute step bound to that same relative step.
    """
    joint_names = action_names or [f"joint_{i}.pos" for i in range(action_dim)]
    excluded = exclude_joints or []

    to_relative = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=excluded,
        action_names=joint_names,
    )
    to_absolute = AbsoluteActionsProcessorStep(enabled=True, relative_step=to_relative)

    pre_pipeline = PolicyProcessorPipeline(steps=[to_relative], name="test_preprocessor")
    post_pipeline = PolicyProcessorPipeline(
        steps=[to_absolute],
        name="test_postprocessor",
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return to_relative, pre_pipeline, post_pipeline
class _ChunkPolicyStub:
    """Minimal policy double that always predicts an all-zero action chunk.

    ``predict_calls`` counts how often ``predict_action_chunk`` was invoked.
    ``select_action`` fails loudly so a test notices if the engine falls back
    to the per-step API.
    """

    def __init__(self, action_dim: int, n_action_steps: int):
        self.predict_calls = 0
        self._chunk = torch.zeros((1, n_action_steps, action_dim))
        self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)

    def reset(self):
        """No state to reset; present only to satisfy the policy interface."""
        return None

    def predict_action_chunk(self, batch):
        """Record the call and hand back a fresh copy of the zero chunk."""
        self.predict_calls = self.predict_calls + 1
        fresh_chunk = torch.clone(self._chunk)
        return fresh_chunk

    def select_action(self, batch):
        """Always raise: the sync engine must use chunk outputs instead."""
        raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
class TestActionQueueRelativeActions: class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot.""" """Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
@@ -120,6 +194,82 @@ class TestActionQueueRelativeActions:
torch.testing.assert_close(first_action, absolute_actions[0]) torch.testing.assert_close(first_action, absolute_actions[0])
class TestRolloutInferenceRelativeActions:
    """Regression tests for rollout inference engines with relative-action policies."""

    def test_sync_engine_postprocesses_chunk_before_queueing(self):
        """Queued sync actions must stay anchored to the state from the chunk-producing step."""
        joint_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
        _, pre_pipeline, post_pipeline = _make_relative_sync_pipelines(ACTION_DIM)
        stub_policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)

        engine = SyncInferenceEngine(
            policy=stub_policy,
            preprocessor=pre_pipeline,
            postprocessor=post_pipeline,
            dataset_features={ACTION: {"names": joint_keys}},
            ordered_action_keys=joint_keys,
            task="test",
            device="cpu",
            robot_type="mock",
        )

        first_state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
        second_state = first_state * 10

        # The second call carries a different state but must be served from
        # the queue filled by the first call.
        first_action = engine.get_action({OBS_STATE: first_state.copy()})
        second_action = engine.get_action({OBS_STATE: second_state.copy()})

        # The stub predicts zeros, so both actions equal the first state.
        for queued_action in (first_action, second_action):
            torch.testing.assert_close(queued_action, torch.from_numpy(first_state))
        assert stub_policy.predict_calls == 1

    def test_rtc_reanchoring_prefers_raw_cached_state(self):
        """RTC re-anchoring must use the raw state cached before observation normalization."""
        action_dim = ACTION_DIM
        joint_names = [f"joint_{i}.pos" for i in range(action_dim)]
        relative_step = RelativeActionsProcessorStep(enabled=True, action_names=joint_names)

        features = {
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
        }
        state_stats = {
            "mean": np.arange(1, action_dim + 1, dtype=np.float32),
            "std": np.full(action_dim, 2.0, dtype=np.float32),
        }
        action_stats = {
            "mean": np.zeros(action_dim, dtype=np.float32),
            "std": np.ones(action_dim, dtype=np.float32),
        }
        normalizer = NormalizerProcessorStep(
            features=features,
            norm_map={
                FeatureType.STATE: NormalizationMode.MEAN_STD,
                FeatureType.ACTION: NormalizationMode.MEAN_STD,
            },
            stats={OBS_STATE: state_stats, ACTION: action_stats},
        )
        # The relative step runs first so it sees the state before normalization.
        preprocessor = PolicyProcessorPipeline(
            steps=[relative_step, normalizer],
            name="test_preprocessor_with_state_norm",
        )

        raw_state = torch.arange(5, 5 + action_dim, dtype=torch.float32).unsqueeze(0)
        preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
        normalized_state = preprocessed.get(OBS_STATE)

        current_state = get_current_raw_state(relative_step, normalized_state)

        torch.testing.assert_close(current_state, raw_state)
        assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
class TestRTCDenoiseWithRelativeLeftovers: class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over.""" """Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""