mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
fix
This commit is contained in:
@@ -96,6 +96,10 @@ class RLearNConfig(PreTrainedConfig):
|
||||
categorical_rewards: bool = False
|
||||
reward_bins: int = 10 # only used if categorical_rewards=True
|
||||
|
||||
# Optional: path to episodes.jsonl to build full-episode indices automatically
|
||||
# Default to common dataset layout: <dataset_root>/meta/episodes.jsonl
|
||||
episodes_jsonl_path: str | None = "meta/episodes.jsonl"
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Require at least one image feature. Language is recommended but optional (can be blank).
|
||||
if not self.image_features:
|
||||
|
||||
@@ -189,6 +189,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
self.stride = max(1, config.stride)
|
||||
|
||||
# Auto-load episode_data_index from episodes.jsonl if not provided
|
||||
if self.episode_data_index is None and getattr(config, "episodes_jsonl_path", None):
|
||||
try:
|
||||
self.episode_data_index = self._load_episode_index_from_jsonl(config.episodes_jsonl_path)
|
||||
except Exception:
|
||||
# Defer to runtime error with guidance if loading fails
|
||||
self.episode_data_index = None
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# Train only projections, temporal module and head by default if backbones are frozen
|
||||
return [p for p in self.parameters() if p.requires_grad]
|
||||
@@ -601,6 +609,31 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
return ep, fr
|
||||
|
||||
def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
|
||||
import json
|
||||
lengths: list[int] = []
|
||||
with open(path, "r") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
# Expect keys: episode_index, length
|
||||
lengths.append(int(obj["length"]))
|
||||
|
||||
# Build cumulative from/to (exclusive)
|
||||
starts = [0]
|
||||
for L in lengths[:-1]:
|
||||
starts.append(starts[-1] + L)
|
||||
ends = []
|
||||
for i, L in enumerate(lengths):
|
||||
ends.append(starts[i] + L)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
return {
|
||||
"from": torch.tensor(starts, device=device, dtype=torch.long),
|
||||
"to": torch.tensor(ends, device=device, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
# Helper functions for ReWiND architecture
|
||||
|
||||
|
||||
Reference in New Issue
Block a user