mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability. - Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations. - Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency. - Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles. These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
90 lines
2.0 KiB
YAML
# @package _global_

# Train with:
#
#   python lerobot/scripts/train.py \
#     env=pusht \
#     +dataset=lerobot/pusht_keypoints

# Top-level settings: the random seed and the dataset repository id.
seed: 1
dataset_repo_id: lerobot/pusht_keypoints

# Training-loop settings.
# NOTE(review): the nesting in this section was reconstructed from the key
# structure (keys with an empty value open a nested mapping); confirm it
# against the consumer's config schema.
training:
  offline_steps: 0

  # Offline training dataloader
  num_workers: 4

  batch_size: 128
  grad_clip_norm: 10.0
  # Written with an explicit decimal point so that YAML 1.1 loaders, which
  # require a "." in the mantissa, also resolve it as a float, not a string.
  lr: 3.0e-4

  eval_freq: 50000
  log_freq: 500
  save_freq: 50000

  online_steps: 1000000
  online_rollout_n_episodes: 10
  online_rollout_batch_size: 10
  online_steps_between_rollouts: 1000
  online_sampling_ratio: 1.0
  online_env_seed: 10000
  online_buffer_capacity: 40000
  online_buffer_seed_size: 0
  do_online_rollout_async: false

  # Each value is a quoted string expression that interpolates `fps` and
  # `policy.horizon`. The observation entries range over horizon + 1 steps,
  # the action and reward entries over horizon steps.
  # NOTE(review): assumed to live under `training`; confirm with the consumer.
  delta_timestamps:
    observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    action: "[i / ${fps} for i in range(${policy.horizon})]"
    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

# Policy settings for the `sac` policy.
# NOTE(review): the nesting in this section was reconstructed from the key
# structure (keys with an empty value open a nested mapping); confirm it
# against the consumer's config schema.
policy:
  name: sac

  # Explicit null instead of a bare empty value; both load as null.
  pretrained_model_path: null

  # Input / output structure.
  n_action_repeats: 1
  horizon: 5
  n_action_steps: 5

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.environment_state: [16]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.environment_state: min_max
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Neural networks.
  # image_encoder_hidden_dim: 32
  discount: 0.99
  temperature_init: 1.0
  num_critics: 2
  # `None` is not a YAML null: it loads as the string "None". Use `null` so
  # the consumer receives an actual null value.
  num_subsample_critics: null
  # Learning rates carry an explicit decimal point so that YAML 1.1 loaders
  # also resolve them as floats rather than strings.
  critic_lr: 3.0e-4
  actor_lr: 3.0e-4
  temperature_lr: 3.0e-4
  critic_target_update_weight: 0.005
  utd_ratio: 2

  # # Loss coefficients.
  # reward_coeff: 0.5
  # expectile_weight: 0.9
  # value_coeff: 0.1
  # consistency_coeff: 20.0
  # advantage_scaling: 3.0
  # pi_coeff: 0.5
  # temporal_decay_coeff: 0.5
  # # Target model.
  # target_model_momentum: 0.995