diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index 075dc23dc..1a406b1aa 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -96,6 +96,11 @@ class RLearNConfig(PreTrainedConfig):
     categorical_rewards: bool = False
     reward_bins: int = 10  # only used if categorical_rewards=True
 
+    # Optional: path to episodes.jsonl used to build full-episode indices automatically.
+    # Defaults to the common dataset layout (meta/episodes.jsonl). The policy opens this path
+    # as given, so a relative value is resolved against the current working directory.
+    episodes_jsonl_path: str | None = "meta/episodes.jsonl"
+
     def validate_features(self) -> None:
         # Require at least one image feature. Language is recommended but optional (can be blank).
         if not self.image_features:
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index c533f7b64..e9979b7ac 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -189,6 +189,19 @@ class RLearNPolicy(PreTrainedPolicy):
         self.frame_dropout_p = config.frame_dropout_p
         self.stride = max(1, config.stride)
 
+        # Auto-load episode_data_index from episodes.jsonl if not provided
+        if self.episode_data_index is None and getattr(config, "episodes_jsonl_path", None):
+            try:
+                self.episode_data_index = self._load_episode_index_from_jsonl(config.episodes_jsonl_path)
+            except Exception as exc:
+                # Best effort: episode_data_index is still None here, so we defer to the runtime
+                # error with guidance; log the cause instead of discarding it silently.
+                import logging
+
+                logging.getLogger(__name__).warning(
+                    "Could not load episode index from %s: %s", config.episodes_jsonl_path, exc
+                )
+
     def get_optim_params(self) -> dict:
         # Train only projections, temporal module and head by default if backbones are frozen
         return [p for p in self.parameters() if p.requires_grad]
@@ -601,6 +614,45 @@ class RLearNPolicy(PreTrainedPolicy):
 
         return ep, fr
 
+    def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
+        """Build per-episode frame boundaries from an ``episodes.jsonl`` file.
+
+        Every non-blank line is parsed as a JSON object (expected keys:
+        ``episode_index``, ``length``); only ``length`` is read. Episodes are laid
+        out back to back in file order, so episode ``i`` spans ``[from[i], to[i])``.
+
+        Args:
+            path: Location of the JSON-lines file, opened as given.
+
+        Returns:
+            Dict with long tensors ``"from"`` (inclusive starts) and ``"to"``
+            (exclusive ends), one entry per episode; both are empty for an empty file.
+
+        Raises:
+            OSError: If the file cannot be opened.
+            ValueError: If a line is not valid JSON or ``length`` is not a valid integer.
+            TypeError: If ``length`` has a type that ``int()`` does not accept.
+            KeyError: If a record has no ``"length"`` key.
+        """
+        import json
+        from itertools import accumulate
+
+        with open(path, encoding="utf-8") as f:
+            # Blank lines carry no record and are skipped.
+            lengths = [int(json.loads(line)["length"]) for line in f if line.strip()]
+
+        # Running totals are the exclusive ends; each start is the previous episode's end.
+        ends = list(accumulate(lengths))
+        starts = [0, *ends[:-1]] if ends else []
+
+        # A module without parameters falls back to CPU instead of raising StopIteration.
+        first_param = next(self.parameters(), None)
+        device = first_param.device if first_param is not None else torch.device("cpu")
+        return {
+            "from": torch.tensor(starts, device=device, dtype=torch.long),
+            "to": torch.tensor(ends, device=device, dtype=torch.long),
+        }
+
 
 # Helper functions for ReWiND architecture
 