mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
chore(docs): update doctrines pipeline files (#1872)
* docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -98,9 +98,7 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
@@ -406,9 +402,7 @@ def act_with_policy(
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
@@ -653,9 +647,7 @@ def interactions_stream(
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
@@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
# Utilities functions
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
|
||||
@@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
|
||||
#################################################
|
||||
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||
if not use_threads(cfg):
|
||||
@@ -250,9 +245,7 @@ def start_learner_threads(
|
||||
logging.info("[LEARNER] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
#################################################
|
||||
# Training setup functions #
|
||||
#################################################
|
||||
# Training setup functions
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
|
||||
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
|
||||
return offline_replay_buffer
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities/Helpers functions #
|
||||
#################################################
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
|
||||
@@ -65,6 +65,28 @@ def update_policy(
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
||||
|
||||
Args:
|
||||
train_metrics: A MetricsTracker instance to record training statistics.
|
||||
policy: The policy model to be trained.
|
||||
batch: A batch of training data.
|
||||
optimizer: The optimizer used to update the policy's parameters.
|
||||
grad_clip_norm: The maximum norm for gradient clipping.
|
||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The updated MetricsTracker with new statistics for this step.
|
||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
@@ -108,6 +130,20 @@ def update_policy(
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
"""
|
||||
Main function to train a policy.
|
||||
|
||||
This function orchestrates the entire training pipeline, including:
|
||||
- Setting up logging, seeding, and device configuration.
|
||||
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
|
||||
- Handling resumption from a checkpoint.
|
||||
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
|
||||
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
|
||||
- Pushing the final trained model to the Hugging Face Hub if configured.
|
||||
|
||||
Args:
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
"""
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user