[HIL-SERL] Review feedback modifications (#1112)

This commit is contained in:
Adil Zouitine
2025-05-15 15:24:41 +02:00
committed by GitHub
parent c7a3973653
commit 2051dd38fc
17 changed files with 504 additions and 180 deletions

View File

@@ -14,6 +14,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Learner server runner for distributed HILSerl robot policy training.
This script implements the learner component of the distributed HILSerl architecture.
It initializes the policy network, maintains replay buffers, and updates
the policy based on transitions received from the actor server.
Examples of usage:
- Start a learner server for training:
```bash
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with specific SAC hyperparameters:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--learner.sac.alpha=0.1 \
--learner.sac.gamma=0.99
```
- Run with a specific dataset and wandb logging:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--dataset.repo_id=username/pick_lift_cube \
--wandb.enable=true \
--wandb.project=hilserl_training
```
- Run with a pretrained policy for fine-tuning:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
```
- Run with a reward classifier model:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--reward_classifier_pretrained_path=outputs/reward_model/best_model
```
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
to communicate with actors.
**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true
in your configuration.
**WORKFLOW**:
1. Create training configuration with proper policy, dataset, and environment settings
2. Start this learner server with the configuration
3. Start an actor server with the same configuration
4. Monitor training progress through wandb dashboard
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
@@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import (
LOG_PREFIX = "[LEARNER]"
logging.basicConfig(level=logging.INFO)
#################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
@@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter")
display_pid = False
if not use_threads(cfg):
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Learner logging initialized, writing to {log_file}")
logging.info(pformat(cfg.to_dict()))
@@ -275,7 +338,7 @@ def add_actor_information_and_train(
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Initialized logging for actor information and training process")
logging.info("Initializing policy")
@@ -604,7 +667,7 @@ def start_learner_server(
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file)
init_logging(log_file=log_file, display_pid=True)
logging.info("Learner server process logging initialized")
# Setup process handlers to handle shutdown signal