From 841d54c050269d3d548c418feeb1eafaa34600fc Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 11 Mar 2025 12:23:51 +0100 Subject: [PATCH] Use sampler always (temp fix) --- lerobot/scripts/train.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 52fed33de..1ad16a153 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -25,7 +25,7 @@ from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.sampler import PrioritizedSampler +from lerobot.common.datasets.sampler import EpisodeAwareSampler, PrioritizedSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.optim.factory import make_optimizer_and_scheduler @@ -166,18 +166,26 @@ def train(cfg: TrainPipelineConfig): # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False - sampler = PrioritizedSampler( - data_len=data_len, - alpha=0.6, - beta=0.1, - eps=1e-6, - replacement=True, - num_samples_per_epoch=data_len, + sampler = EpisodeAwareSampler( + dataset.episode_data_index, + drop_n_last_frames=cfg.policy.drop_n_last_frames, + shuffle=True, ) else: shuffle = True sampler = None + # TODO(pepijn): If experiment works integrate this + shuffle = False + sampler = PrioritizedSampler( + data_len=data_len, + alpha=0.6, + beta=0.1, + eps=1e-6, + replacement=True, + num_samples_per_epoch=data_len, + ) + dataloader = torch.utils.data.DataLoader( dataset, num_workers=cfg.num_workers,