use only rewind loss

This commit is contained in:
Pepijn
2025-08-28 14:22:57 +02:00
parent a4c88d6340
commit c877e98658
11 changed files with 445 additions and 520 deletions

View File

@@ -136,9 +136,12 @@ def train(cfg: TrainPipelineConfig):
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
# Pass episode_data_index for RLearN to calculate proper progress
episode_data_index = dataset.episode_data_index if hasattr(dataset, 'episode_data_index') else None
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
episode_data_index=episode_data_index,
)
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats