mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
3 Commits
codex/roll
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
727ca1a92c | ||
|
|
ee737b72d0 | ||
|
|
cf2e42f557 |
@@ -24,9 +24,8 @@ from copy import copy
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
@@ -62,17 +61,6 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._robot_type = robot_type
|
||||
self._processed_action_queue: deque[torch.Tensor] = deque()
|
||||
|
||||
self._relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
if self._relative_step is not None and self._relative_step.action_names is None:
|
||||
cfg_names = getattr(policy.config, "action_feature_names", None)
|
||||
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
|
||||
if action_names:
|
||||
self._relative_step.action_names = list(action_names)
|
||||
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
|
||||
|
||||
logger.info(
|
||||
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
||||
self._device,
|
||||
@@ -96,7 +84,7 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._processed_action_queue.clear()
|
||||
|
||||
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
|
||||
"""Queue postprocessed per-step actions in policy output order."""
|
||||
"""Convert a postprocessed action chunk into ordered per-step CPU tensors."""
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk.unsqueeze(0)
|
||||
|
||||
@@ -104,13 +92,12 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
|
||||
|
||||
for action in action_chunk.squeeze(0):
|
||||
action_tensor = action.detach().cpu()
|
||||
if len(action_tensor) != len(self._ordered_action_keys):
|
||||
raise ValueError(
|
||||
f"Action tensor length ({len(action_tensor)}) != action keys "
|
||||
f"({len(self._ordered_action_keys)})"
|
||||
)
|
||||
self._processed_action_queue.append(action_tensor)
|
||||
action_tensor = action.cpu()
|
||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
||||
ordered_action = torch.tensor(
|
||||
[action_dict[k] for k in self._ordered_action_keys], dtype=action_tensor.dtype
|
||||
)
|
||||
self._processed_action_queue.append(ordered_action)
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
||||
|
||||
@@ -47,8 +47,12 @@ class BaseStrategy(RolloutStrategy):
|
||||
interpolator = self._interpolator
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
observation_interval = 1.0 / cfg.fps
|
||||
|
||||
start_time = time.perf_counter()
|
||||
next_observation_time = 0.0
|
||||
obs = None
|
||||
obs_processed = None
|
||||
engine.resume()
|
||||
logger.info("Base strategy control loop started")
|
||||
|
||||
@@ -59,13 +63,18 @@ class BaseStrategy(RolloutStrategy):
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
if obs is None or loop_start >= next_observation_time:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
next_observation_time = loop_start + observation_interval
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
if obs_processed is None:
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
|
||||
@@ -225,66 +225,6 @@ class TestRolloutInferenceRelativeActions:
|
||||
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
|
||||
assert policy.predict_calls == 1
|
||||
|
||||
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
|
||||
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
|
||||
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
exclude_joints=["gripper"],
|
||||
)
|
||||
relative_step.action_names = None
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
dataset_features = {ACTION: {"names": action_names}}
|
||||
|
||||
assert relative_step.action_names is None
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
expected = torch.from_numpy(state.copy())
|
||||
expected[-1] = 0.0
|
||||
|
||||
torch.testing.assert_close(action, expected)
|
||||
assert relative_step.action_names == action_names
|
||||
|
||||
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
|
||||
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
)
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features={ACTION: {"names": list(reversed(action_names))}},
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
|
||||
torch.testing.assert_close(action, torch.from_numpy(state))
|
||||
|
||||
def test_rtc_reanchoring_prefers_raw_cached_state(self):
|
||||
"""RTC re-anchoring must use the raw state cached before observation normalization."""
|
||||
action_dim = ACTION_DIM
|
||||
|
||||
Reference in New Issue
Block a user