change sampling

This commit is contained in:
Pepijn
2025-08-31 15:20:20 +02:00
parent 4557655ab1
commit def71cc439
3 changed files with 424 additions and 189 deletions

View File

@@ -248,6 +248,15 @@ def train(cfg: TrainPipelineConfig):
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
elif cfg.policy.type == "rlearn":
# For RLearN, drop first 15 frames to avoid padding issues with temporal windows
shuffle = False
sampler = EpisodeAwareSampler(
dataset.episode_data_index,
drop_n_first_frames=15, # Skip frames that would need padding
drop_n_last_frames=0,
shuffle=True,
)
else:
shuffle = True
sampler = None