# Patch summary (reviewer note; this text precedes the first "diff --git" header
# and is not part of any hunk).
#
# * src/lerobot/scripts/lerobot_train.py: train() now passes cfg.batch_size
#   instead of effective_batch_size as the first MetricsTracker argument.
#   effective_batch_size is still computed in this hunk.
# * src/lerobot/utils/logging_utils.py: MetricsTracker.__init__ multiplies the
#   initial sample count (steps * _batch_size) by accelerator.num_processes,
#   falling back to 1 when no accelerator is given. MetricsTracker.step() keeps
#   its per-step increment of _batch_size * num_processes and only names the
#   factor `world_size`.
# * tests/utils/test_logging_utils.py: adds a MockAccelerator stub exposing
#   num_processes, plus two tests covering initialization and step() with
#   num_processes=2.
#
# NOTE(review): presumably _batch_size holds the first constructor argument, in
# which case the old call site counted num_processes twice per step -- confirm
# against the full MetricsTracker.__init__, which is not visible in this patch.
# NOTE(review): line breaks and indentation were reconstructed from the hunk
# header counts and Python block structure; confirm whitespace against the
# target files before applying.
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 465cbf531..04d43d91e 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         "dataloading_s": AverageMeter("data_s", ":.3f"),
     }
 
-    # Use effective batch size for proper epoch calculation in distributed training
+    # Keep global batch size for logging; MetricsTracker handles world size internally.
     effective_batch_size = cfg.batch_size * accelerator.num_processes
     train_tracker = MetricsTracker(
-        effective_batch_size,
+        cfg.batch_size,
         dataset.num_frames,
         dataset.num_episodes,
         train_metrics,
diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py
index c4c1f42e0..1497c0585 100644
--- a/src/lerobot/utils/logging_utils.py
+++ b/src/lerobot/utils/logging_utils.py
@@ -104,9 +104,10 @@ class MetricsTracker:
         self.metrics = metrics
 
         self.steps = initial_step
+        world_size = accelerator.num_processes if accelerator else 1
         # A sample is an (observation,action) pair, where observation and action
         # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
-        self.samples = self.steps * self._batch_size
+        self.samples = self.steps * self._batch_size * world_size
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
         self.accelerator = accelerator
@@ -132,7 +133,8 @@ class MetricsTracker:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
+        world_size = self.accelerator.num_processes if self.accelerator else 1
+        self.samples += self._batch_size * world_size
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
 
diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py
index 560ba5701..1207534c0 100644
--- a/tests/utils/test_logging_utils.py
+++ b/tests/utils/test_logging_utils.py
@@ -24,6 +24,11 @@ def mock_metrics():
     return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
 
 
+class MockAccelerator:
+    def __init__(self, num_processes: int):
+        self.num_processes = num_processes
+
+
 def test_average_meter_initialization():
     meter = AverageMeter("loss", ":.2f")
     assert meter.name == "loss"
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
     assert tracker.epochs == tracker.samples / 1000
 
 
+def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=10,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    assert tracker.steps == 10
+    assert tracker.samples == 10 * 32 * 2
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
+def test_metrics_tracker_step_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=5,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    tracker.step()
+    assert tracker.steps == 6
+    assert tracker.samples == (5 * 32 * 2) + (32 * 2)
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
 def test_metrics_tracker_getattr(mock_metrics):
     tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
     assert tracker.loss == mock_metrics["loss"]