mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
chore(docs): update doctrines pipeline files (#1872)
* docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -65,6 +65,28 @@ def update_policy(
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
||||
|
||||
Args:
|
||||
train_metrics: A MetricsTracker instance to record training statistics.
|
||||
policy: The policy model to be trained.
|
||||
batch: A batch of training data.
|
||||
optimizer: The optimizer used to update the policy's parameters.
|
||||
grad_clip_norm: The maximum norm for gradient clipping.
|
||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The updated MetricsTracker with new statistics for this step.
|
||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
@@ -108,6 +130,20 @@ def update_policy(
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
"""
|
||||
Main function to train a policy.
|
||||
|
||||
This function orchestrates the entire training pipeline, including:
|
||||
- Setting up logging, seeding, and device configuration.
|
||||
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
|
||||
- Handling resumption from a checkpoint.
|
||||
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
|
||||
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
|
||||
- Pushing the final trained model to the Hugging Face Hub if configured.
|
||||
|
||||
Args:
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
"""
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user