feat(processor): enhance type safety with generic DataProcessorPipeline for policy and robot pipelines (#1915)

* refactor(processor): enhance type annotations for processors in record, replay, teleoperate, and control utils

- Updated type annotations for preprocessor and postprocessor parameters in record_loop and predict_action functions to specify the expected dictionary types.
- Adjusted robot_action_processor type in ReplayConfig and TeleoperateConfig to improve clarity and maintainability.
- Ensured consistency in type definitions across multiple files, enhancing overall code readability.

* refactor(processor): enhance type annotations for RobotProcessorPipeline in various files

- Updated type annotations for RobotProcessorPipeline instances in evaluate.py, record.py, replay.py, teleoperate.py, and other related files to specify input and output types more clearly.
- Introduced new type conversions for PolicyAction and EnvTransition to improve type safety and maintainability across the processing pipelines.
- Ensured consistency in type definitions, enhancing overall code readability and reducing potential runtime errors.

* refactor(processor): update transition handling in processors to use transition_to_batch

- Replaced direct transition handling with transition_to_batch in various processor tests and implementations to ensure consistent batching of input data.
- Updated assertions in tests to reflect changes in data structure, enhancing clarity and maintainability.
- Improved overall code readability by standardizing the way transitions are processed across different processor types.

* refactor(tests): standardize transition key usage in processor tests

- Updated assertions in processor test files to utilize the TransitionKey for action references, enhancing consistency across tests.
- Replaced direct string references with TransitionKey constants for improved readability and maintainability.
- Ensured that all relevant tests reflect these changes, contributing to a more uniform approach in handling transitions.
This commit is contained in:
Adil Zouitine
2025-09-11 13:36:04 +02:00
committed by GitHub
parent a2489ab0da
commit 376a6457cf
29 changed files with 671 additions and 786 deletions

View File

@@ -72,7 +72,7 @@ from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.core import TransitionKey
from lerobot.processor.core import PolicyAction
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
@@ -86,8 +86,8 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -159,15 +159,15 @@ def rollout(
observation = add_envs_task(env, observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
action: PolicyAction = policy.select_action(observation)
action: PolicyAction = postprocessor(action)
# Convert to CPU / numpy.
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action)
observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
@@ -181,7 +181,7 @@ def rollout(
# Keep track of which environments are done so far.
done = terminated | truncated | done
all_actions.append(torch.from_numpy(action))
all_actions.append(torch.from_numpy(action_numpy))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
@@ -220,8 +220,8 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,