This commit is contained in:
Pepijn
2025-08-30 12:33:39 +02:00
parent 7440d772ff
commit 8ad00d1ee7
2 changed files with 37 additions and 0 deletions

View File

@@ -96,6 +96,10 @@ class RLearNConfig(PreTrainedConfig):
categorical_rewards: bool = False
reward_bins: int = 10 # only used if categorical_rewards=True
# Optional: path to episodes.jsonl to build full-episode indices automatically
# Default to common dataset layout: <dataset_root>/meta/episodes.jsonl
episodes_jsonl_path: str | None = "meta/episodes.jsonl"
def validate_features(self) -> None:
# Require at least one image feature. Language is recommended but optional (can be blank).
if not self.image_features:

View File

@@ -189,6 +189,14 @@ class RLearNPolicy(PreTrainedPolicy):
self.frame_dropout_p = config.frame_dropout_p
self.stride = max(1, config.stride)
# Auto-load episode_data_index from episodes.jsonl if not provided
if self.episode_data_index is None and getattr(config, "episodes_jsonl_path", None):
try:
self.episode_data_index = self._load_episode_index_from_jsonl(config.episodes_jsonl_path)
except Exception:
# Defer to runtime error with guidance if loading fails
self.episode_data_index = None
def get_optim_params(self) -> dict:
# Train only projections, temporal module and head by default if backbones are frozen
return [p for p in self.parameters() if p.requires_grad]
@@ -601,6 +609,31 @@ class RLearNPolicy(PreTrainedPolicy):
return ep, fr
def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
import json
lengths: list[int] = []
with open(path, "r") as f:
for line in f:
if not line.strip():
continue
obj = json.loads(line)
# Expect keys: episode_index, length
lengths.append(int(obj["length"]))
# Build cumulative from/to (exclusive)
starts = [0]
for L in lengths[:-1]:
starts.append(starts[-1] + L)
ends = []
for i, L in enumerate(lengths):
ends.append(starts[i] + L)
device = next(self.parameters()).device
return {
"from": torch.tensor(starts, device=device, dtype=torch.long),
"to": torch.tensor(ends, device=device, dtype=torch.long),
}
# Helper functions for ReWiND architecture