mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
9 Commits
feat/dart
...
tmp/fold_t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a07ff640bb | ||
|
|
0753415244 | ||
|
|
a054663e38 | ||
|
|
aceb651e40 | ||
|
|
76a4529d29 | ||
|
|
39e14c086c | ||
|
|
0af2029328 | ||
|
|
c027b2971c | ||
|
|
9cc203034e |
480
examples/port_datasets/slurm_mirror_dataset.py
Normal file
480
examples/port_datasets/slurm_mirror_dataset.py
Normal file
@@ -0,0 +1,480 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual robot dataset using SLURM for distributed video processing.
|
||||
|
||||
This script creates a mirrored version of a dataset where:
|
||||
1. Left and right arm observations/actions are swapped
|
||||
2. Joint values are inverted according to a mirroring mask
|
||||
3. Video frames are horizontally flipped (parallelized via SLURM)
|
||||
|
||||
Example usage:
|
||||
```shell
|
||||
# SLURM execution
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id pepijn/openarm_bimanual \
|
||||
--output-repo-id pepijn/openarm_bimanual_mirrored \
|
||||
--logs-dir /fsx/user/logs \
|
||||
--partition hopper-cpu
|
||||
|
||||
# Local execution (for debugging)
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id pepijn/openarm_bimanual \
|
||||
--output-repo-id pepijn/openarm_bimanual_mirrored \
|
||||
--slurm 0 \
|
||||
--push-to-hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-joint sign table for mirroring an OpenArm joint vector:
# -1 means the joint value is negated, 1 means it is kept as is.
# NOTE(review): nothing in the visible part of this file reads this constant;
# MirrorDataAndMetadata.run builds its own local MIRRORING_MASK with the same
# entries. Confirm whether this module-level copy is still needed.
OPENARM_MIRRORING_MASK = {
    "joint_1": -1,
    "joint_2": -1,
    "joint_3": -1,
    "joint_4": 1,
    "joint_5": -1,
    "joint_6": -1,
    "joint_7": -1,
    "gripper": 1,
}
|
||||
|
||||
|
||||
class MirrorVideos(PipelineStep):
    """Pipeline step that horizontally flips the video files assigned to this rank.

    Every rank builds the same ordered list of source video files and keeps the
    entries whose position satisfies ``i % world_size == rank``. Each kept file is
    re-encoded with ffmpeg's ``hflip`` filter and written under the output root,
    with "left_"/"right_" swapped in the video key part of the path.

    Args:
        repo_id: Source dataset repo_id.
        output_repo_id: Output dataset repo_id (used for the default output root).
        root: Optional source dataset root directory.
        output_root: Optional output dataset root directory. Defaults to
            ``HF_LEROBOT_HOME / output_repo_id``.
        vcodec: Video codec passed to ffmpeg's ``-c:v``.
    """

    def __init__(
        self,
        repo_id: str,
        output_repo_id: str,
        root: str | None = None,
        output_root: str | None = None,
        vcodec: str = "libsvtav1",
    ):
        super().__init__()
        self.repo_id = repo_id
        self.output_repo_id = output_repo_id
        self.root = root
        self.output_root = output_root
        self.vcodec = vcodec

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Flip this rank's share of the dataset videos.

        Args:
            data: Unused.
            rank: Index of this worker.
            world_size: Total number of workers sharing the video list.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status for a video.
        """
        # Imports and helpers are kept local to `run`, as in the rest of this file.
        import logging
        import subprocess
        from pathlib import Path

        from datasets.utils.tqdm import disable_progress_bars

        from lerobot.datasets.lerobot_dataset import LeRobotDataset
        from lerobot.utils.constants import HF_LEROBOT_HOME
        from lerobot.utils.utils import init_logging

        init_logging()
        disable_progress_bars()
        logger = logging.getLogger(__name__)

        def swap_left_right_name(name: str) -> str:
            """Swap every "left_" and "right_" substring in `name`."""
            result = name.replace("left_", "LEFT_PLACEHOLDER_")
            result = result.replace("right_", "left_")
            result = result.replace("LEFT_PLACEHOLDER_", "right_")
            return result

        def flip_video_frames(input_path: Path, output_path: Path, fps: float, vcodec: str):
            """Re-encode `input_path` into `output_path` with a horizontal flip."""
            output_path.parent.mkdir(parents=True, exist_ok=True)
            cmd = [
                "ffmpeg", "-y", "-i", str(input_path),
                "-vf", "hflip",
                "-c:v", vcodec,
                "-g", "2",
                "-crf", "30",
                "-r", str(int(fps)),
                "-pix_fmt", "yuv420p",
                "-loglevel", "error",
            ]
            if vcodec == "libsvtav1":
                cmd.extend(["-preset", "12"])
            cmd.append(str(output_path))
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                raise RuntimeError(f"FFmpeg failed: {result.stderr}")

        def video_is_valid(path: Path) -> bool:
            """Return True when ffprobe reports a numeric frame count for `path`."""
            if not path.exists():
                return False
            try:
                result = subprocess.run(
                    ["ffprobe", "-v", "error", "-select_streams", "v:0",
                     "-show_entries", "stream=nb_frames", "-of", "csv=p=0", str(path)],
                    capture_output=True, text=True, timeout=30
                )
                return result.returncode == 0 and result.stdout.strip().isdigit()
            except Exception:
                # Best effort: any probing failure is treated as "not valid" so the
                # file simply gets re-encoded.
                return False

        root = Path(self.root) if self.root else None
        output_root = Path(self.output_root) if self.output_root else None

        dataset = LeRobotDataset(self.repo_id, root=root)
        output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)

        if not dataset.meta.video_keys:
            logger.info(f"Rank {rank}: No videos to process")
            return

        video_tasks = []
        # Several episodes may resolve to the same video file. Queue each source
        # file once so that two ranks never encode (and overwrite) the same
        # destination at the same time. Insertion order is deterministic, which
        # keeps the list identical on every rank.
        seen_sources = set()
        for old_video_key in dataset.meta.video_keys:
            new_video_key = swap_left_right_name(old_video_key)
            for ep_idx in range(dataset.meta.total_episodes):
                try:
                    relative_path = dataset.meta.get_video_file_path(ep_idx, old_video_key)
                except KeyError:
                    continue
                src_path = dataset.root / relative_path
                if src_path in seen_sources:
                    continue
                seen_sources.add(src_path)
                if not src_path.exists():
                    continue
                dst_relative_str = str(relative_path).replace(old_video_key, new_video_key)
                dst_path = output_root / dst_relative_str
                video_tasks.append((src_path, dst_path, ep_idx, old_video_key))

        my_tasks = [t for i, t in enumerate(video_tasks) if i % world_size == rank]
        logger.info(f"Rank {rank}/{world_size}: Processing {len(my_tasks)}/{len(video_tasks)} videos")

        for src_path, dst_path, _ep_idx, _video_key in my_tasks:
            # A destination that already probes as a valid video is left alone,
            # which makes re-runs resume instead of starting over.
            if video_is_valid(dst_path):
                logger.info(f"Rank {rank}: Skipping {dst_path.name} (already done)")
                continue
            logger.info(f"Rank {rank}: Processing {src_path.name} -> {dst_path.name}")
            flip_video_frames(src_path, dst_path, dataset.meta.fps, self.vcodec)
|
||||
|
||||
|
||||
class MirrorDataAndMetadata(PipelineStep):
    """Pipeline step that mirrors parquet data and metadata (runs once on rank 0).

    Under the output root it writes:
      * feature definitions with "left_"/"right_" swapped in keys and names,
      * frame parquet files whose "action" and "observation.state" vectors are
        re-ordered to the swapped names and sign-flipped per the mirroring mask,
      * episode-metadata parquet files with ``videos/<key>/...`` columns renamed,
      * the info, tasks and stats files.

    A ``.data_mirrored`` marker file is created at the end; when the marker is
    already present a later run returns immediately.

    Args:
        repo_id: Source dataset repo_id.
        output_repo_id: Output dataset repo_id.
        root: Optional source dataset root directory.
        output_root: Optional output dataset root directory. Defaults to
            ``HF_LEROBOT_HOME / output_repo_id``.
    """

    def __init__(
        self,
        repo_id: str,
        output_repo_id: str,
        root: str | None = None,
        output_root: str | None = None,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.output_repo_id = output_repo_id
        self.root = root
        self.output_root = output_root

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Mirror data and metadata; a no-op on every rank except rank 0.

        Args:
            data: Unused.
            rank: Index of this worker.
            world_size: Unused.

        Raises:
            ValueError: If the dataset's robot type has no mirroring mask.
        """
        if rank != 0:
            return

        # Imports and helpers are kept local to `run`, as in the rest of this file.
        import logging
        import shutil
        from pathlib import Path

        import numpy as np
        import pandas as pd
        from datasets.utils.tqdm import disable_progress_bars

        from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
        from lerobot.datasets.utils import DATA_DIR, DEFAULT_DATA_PATH, write_info, write_stats, write_tasks
        from lerobot.utils.constants import HF_LEROBOT_HOME
        from lerobot.utils.utils import init_logging

        init_logging()
        disable_progress_bars()
        logger = logging.getLogger(__name__)

        # Sign applied to each joint when mirroring: -1 negates, 1 keeps the value.
        MIRRORING_MASK = {
            "joint_1": -1, "joint_2": -1, "joint_3": -1, "joint_4": 1,
            "joint_5": -1, "joint_6": -1, "joint_7": -1, "gripper": 1,
        }

        def get_mirroring_mask(robot_type: str) -> dict[str, int]:
            """Return the mirroring mask for `robot_type` or raise ValueError."""
            if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
                return MIRRORING_MASK
            raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")

        def swap_left_right_name(name: str) -> str:
            """Swap every "left_" and "right_" substring in `name`."""
            result = name.replace("left_", "LEFT_PLACEHOLDER_")
            result = result.replace("right_", "left_")
            result = result.replace("LEFT_PLACEHOLDER_", "right_")
            return result

        def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
            """Return the swapped names and a map from old index to new index."""
            mirrored_names = [swap_left_right_name(n) for n in names]
            old_to_new_idx = {}
            for old_idx, old_name in enumerate(names):
                new_name = swap_left_right_name(old_name)
                new_idx = mirrored_names.index(new_name)
                old_to_new_idx[old_idx] = new_idx
            return mirrored_names, old_to_new_idx

        def joint_sign(feature_name: str, mirroring_mask: dict[str, int]) -> int:
            """Return the mask sign for `feature_name`, or 1 when no joint matches.

            The part before the first "." is tried as a joint name first, so an
            unprefixed name such as "joint_1.pos" resolves to "joint_1". Only when
            that fails is the text up to the first "_" dropped (e.g. the "left_"
            in "left_joint_1.pos"); dropping it unconditionally would turn
            "joint_1.pos" into "1" and silently skip the sign flip.
            """
            base_name = feature_name.split(".")[0]
            if base_name in mirroring_mask:
                return mirroring_mask[base_name]
            if "_" in feature_name:
                stripped_name = feature_name.split("_", 1)[1].split(".")[0]
                return mirroring_mask.get(stripped_name, 1)
            return 1

        def build_mirror_plan(names: list[str], mirroring_mask: dict[str, int]) -> list[tuple[int, int, int]]:
            """Return (old_idx, new_idx, sign) triples describing how to mirror a vector."""
            mirrored_names, idx_mapping = mirror_feature_names(names)
            return [
                (old_idx, new_idx, joint_sign(mirrored_names[new_idx], mirroring_mask))
                for old_idx, new_idx in idx_mapping.items()
            ]

        def mirror_rows(rows: np.ndarray, plan: list[tuple[int, int, int]]) -> np.ndarray:
            """Apply a mirror plan to every row of a 2-D array (one vector per row)."""
            mirrored = np.zeros_like(rows)
            for old_idx, new_idx, sign in plan:
                mirrored[:, new_idx] = rows[:, old_idx] * sign
            return mirrored

        def mirror_stats(stats: dict) -> dict:
            """Recursively swap "left_"/"right_" in the keys of `stats`."""
            mirrored = {}
            for key, value in stats.items():
                new_key = swap_left_right_name(key)
                if isinstance(value, dict):
                    mirrored[new_key] = mirror_stats(value)
                else:
                    mirrored[new_key] = value
            return mirrored

        root = Path(self.root) if self.root else None
        output_root = Path(self.output_root) if self.output_root else None

        dataset = LeRobotDataset(self.repo_id, root=root)
        output_root = output_root or (HF_LEROBOT_HOME / self.output_repo_id)

        done_marker = output_root / ".data_mirrored"
        if done_marker.exists():
            logger.info("Data and metadata already mirrored, skipping")
            return

        # Clean up partial output from previous failed runs
        # NOTE(review): this removes the whole output root, including any videos a
        # MirrorVideos step on another rank may already have written there.
        # Confirm this step never runs concurrently with video workers.
        if output_root.exists():
            logger.info(f"Removing existing partial output: {output_root}")
            shutil.rmtree(output_root)

        robot_type = dataset.meta.robot_type or "bi_openarms_follower"
        mirroring_mask = get_mirroring_mask(robot_type)

        mirrored_features = {}
        for key, feat in dataset.meta.features.items():
            new_key = swap_left_right_name(key)
            new_feat = feat.copy()
            if "names" in new_feat and new_feat["names"]:
                new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
            mirrored_features[new_key] = new_feat

        new_meta = LeRobotDatasetMetadata.create(
            repo_id=self.output_repo_id,
            fps=dataset.meta.fps,
            features=mirrored_features,
            robot_type=dataset.meta.robot_type,
            root=output_root,
            use_videos=len(dataset.meta.video_keys) > 0,
        )

        if dataset.meta.tasks is not None:
            write_tasks(dataset.meta.tasks, new_meta.root)

        data_dir = dataset.root / DATA_DIR
        parquet_files = sorted(data_dir.glob("*/*.parquet"))

        # The index/sign plan depends only on the feature names, so build it once
        # per column instead of once per row.
        vector_plans = {}
        for column in ("action", "observation.state"):
            names = dataset.meta.features.get(column, {}).get("names", [])
            if names:
                vector_plans[column] = build_mirror_plan(names, mirroring_mask)

        for src_path in parquet_files:
            df = pd.read_parquet(src_path).reset_index(drop=True)
            relative_path = src_path.relative_to(dataset.root)
            # Path parts are (<data dir>, "<prefix>-<chunk>", "<prefix>-<file>.parquet").
            chunk_dir = relative_path.parts[1]
            file_name = relative_path.parts[2]
            chunk_idx = int(chunk_dir.split("-")[1])
            file_idx = int(file_name.split("-")[1].split(".")[0])

            for column, plan in vector_plans.items():
                if column in df.columns:
                    rows = np.stack(df[column].values)
                    df[column] = list(mirror_rows(rows, plan))

            dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
            dst_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_parquet(dst_path, index=False)

        episodes_dir = dataset.root / "meta/episodes"
        dst_episodes_dir = new_meta.root / "meta/episodes"
        if episodes_dir.exists():
            dst_episodes_dir.mkdir(parents=True, exist_ok=True)
            for src_parquet in episodes_dir.glob("*/*.parquet"):
                df = pd.read_parquet(src_parquet)
                # Rename "videos/<key>/..." columns so they point at the swapped key.
                columns_to_rename = {}
                for col in df.columns:
                    if col.startswith("videos/"):
                        parts = col.split("/")
                        if len(parts) >= 2:
                            video_key = parts[1]
                            new_video_key = swap_left_right_name(video_key)
                            new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
                            columns_to_rename[col] = new_col
                if columns_to_rename:
                    df = df.rename(columns=columns_to_rename)
                dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
                dst_parquet.parent.mkdir(parents=True, exist_ok=True)
                df.to_parquet(dst_parquet, index=False)

        new_meta.info.update({
            "total_episodes": dataset.meta.info["total_episodes"],
            "total_frames": dataset.meta.info["total_frames"],
            "total_tasks": dataset.meta.info["total_tasks"],
            "splits": dataset.meta.info.get("splits", {}),
        })
        write_info(new_meta.info, new_meta.root)

        if dataset.meta.stats is not None:
            # NOTE(review): only the keys are renamed here. The per-dimension values
            # of vector stats (e.g. for "action") are neither re-ordered nor
            # sign-flipped, unlike the frame data above. Confirm whether the stats
            # should be mirrored or recomputed.
            mirrored_stats = mirror_stats(dataset.meta.stats)
            write_stats(mirrored_stats, new_meta.root)

        done_marker.touch()
        logger.info(f"Data and metadata mirrored to {output_root}")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
    """Return `name` with every "left_" and "right_" substring exchanged.

    A temporary placeholder stands in for "left_" while "right_" is rewritten,
    so the two replacements do not undo each other.
    """
    placeholder = "LEFT_PLACEHOLDER_"
    return (
        name.replace("left_", placeholder)
        .replace("right_", "left_")
        .replace(placeholder, "right_")
    )
|
||||
|
||||
|
||||
def get_num_video_tasks(repo_id: str, root: str | None = None) -> int:
    """Count the distinct source video files that exist on disk for a dataset.

    The path of every (video key, episode) pair is looked up, but a file that
    several episodes resolve to is counted once, so the result is the number of
    files to process rather than the number of episodes.

    Args:
        repo_id: Source dataset repo_id.
        root: Optional source dataset root directory.

    Returns:
        Number of unique existing video files.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(repo_id, root=Path(root) if root else None)

    checked_paths = set()
    count = 0
    for video_key in dataset.meta.video_keys:
        for ep_idx in range(dataset.meta.total_episodes):
            try:
                relative_path = dataset.meta.get_video_file_path(ep_idx, video_key)
            except KeyError:
                continue
            src_path = dataset.root / relative_path
            if src_path in checked_paths:
                continue
            checked_paths.add(src_path)
            if src_path.exists():
                count += 1
    return count
|
||||
|
||||
|
||||
def make_mirror_executor(
    repo_id: str,
    output_repo_id: str,
    root: str | None,
    output_root: str | None,
    vcodec: str,
    job_name: str,
    logs_dir: Path,
    workers: int,
    partition: str,
    cpus_per_task: int,
    mem_per_cpu: str,
    time_limit: str,
    slurm: bool = True,
):
    """Build the datatrove executor that mirrors a dataset.

    Local mode returns a single-task ``LocalPipelineExecutor`` that runs the
    data/metadata step and then the video step in sequence.

    SLURM mode returns the video executor, which ``depends`` on a separate
    one-task executor for the data/metadata step. MirrorDataAndMetadata deletes
    and recreates the output root, so it must finish before any video worker
    starts writing there; running both steps in one multi-rank pipeline would let
    other ranks write videos while rank 0 is still wiping the directory. Calling
    ``run()`` on the returned executor launches the dependency first.

    Args:
        repo_id: Source dataset repo_id.
        output_repo_id: Output dataset repo_id.
        root: Optional source dataset root directory.
        output_root: Optional output dataset root directory.
        vcodec: Video codec for the flipped videos.
        job_name: SLURM job name; also names the logging sub-directory.
        logs_dir: Base directory for datatrove logs.
        workers: Maximum number of concurrent SLURM workers.
        partition: SLURM partition.
        cpus_per_task: CPUs per SLURM task.
        mem_per_cpu: Value for the ``mem-per-cpu`` sbatch argument.
        time_limit: SLURM time limit.
        slurm: Use SLURM when True, local execution otherwise.

    Returns:
        A ``SlurmPipelineExecutor`` or ``LocalPipelineExecutor`` ready to ``run()``.
    """
    data_step = MirrorDataAndMetadata(repo_id, output_repo_id, root, output_root)
    video_step = MirrorVideos(repo_id, output_repo_id, root, output_root, vcodec)

    if not slurm:
        # One task runs the steps in order, so no cross-rank coordination is needed.
        return LocalPipelineExecutor(
            pipeline=[data_step, video_step],
            logging_dir=str(logs_dir / job_name),
            tasks=1,
            workers=1,
        )

    num_tasks = max(1, get_num_video_tasks(repo_id, root))

    def slurm_resources() -> dict:
        """Return a fresh copy of the resource settings shared by both jobs."""
        return {
            "time": time_limit,
            "partition": partition,
            "cpus_per_task": cpus_per_task,
            "sbatch_args": {
                "mem-per-cpu": mem_per_cpu,
                "requeue": True,
                "signal": "USR1@30",
            },
        }

    data_job_name = f"{job_name}_data"
    data_executor = SlurmPipelineExecutor(
        pipeline=[data_step],
        logging_dir=str(logs_dir / data_job_name),
        job_name=data_job_name,
        tasks=1,
        workers=1,
        **slurm_resources(),
    )
    return SlurmPipelineExecutor(
        pipeline=[video_step],
        logging_dir=str(logs_dir / job_name),
        job_name=job_name,
        tasks=num_tasks,
        workers=min(workers, num_tasks),
        depends=data_executor,
        **slurm_resources(),
    )
|
||||
|
||||
|
||||
def main():
    """Parse the command line, run the mirroring executor and optionally push the result."""
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    cli = argparse.ArgumentParser(description="Mirror a bimanual robot dataset using SLURM")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--repo-id", {"type": str, "required": True, "help": "Source dataset repo_id"}),
        ("--output-repo-id", {"type": str, "required": True, "help": "Output dataset repo_id"}),
        ("--root", {"type": str, "default": None, "help": "Source dataset root directory"}),
        ("--output-root", {"type": str, "default": None, "help": "Output dataset root directory"}),
        ("--vcodec", {"type": str, "default": "libsvtav1", "help": "Video codec"}),
        ("--logs-dir", {"type": Path, "default": Path("logs"), "help": "Directory for datatrove logs"}),
        ("--job-name", {"type": str, "default": "mirror_dataset", "help": "SLURM job name"}),
        ("--slurm", {"type": int, "default": 1, "help": "Use SLURM (1) or local (0)"}),
        ("--workers", {"type": int, "default": 64, "help": "Number of SLURM workers"}),
        ("--partition", {"type": str, "default": "hopper-cpu", "help": "SLURM partition"}),
        ("--cpus-per-task", {"type": int, "default": 4, "help": "CPUs per task"}),
        ("--mem-per-cpu", {"type": str, "default": "2G", "help": "Memory per CPU"}),
        ("--time-limit", {"type": str, "default": "04:00:00", "help": "SLURM time limit"}),
        ("--push-to-hub", {"action": "store_true", "help": "Push mirrored dataset to HuggingFace Hub"}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    use_slurm = args.slurm == 1
    executor = make_mirror_executor(
        repo_id=args.repo_id,
        output_repo_id=args.output_repo_id,
        root=args.root,
        output_root=args.output_root,
        vcodec=args.vcodec,
        job_name=args.job_name,
        logs_dir=args.logs_dir,
        workers=args.workers,
        partition=args.partition,
        cpus_per_task=args.cpus_per_task,
        mem_per_cpu=args.mem_per_cpu,
        time_limit=args.time_limit,
        slurm=use_slurm,
    )
    executor.run()

    if not args.push_to_hub:
        return

    # NOTE(review): the push runs as soon as executor.run() returns. Confirm that
    # in SLURM mode the jobs have actually finished by then; the usage example in
    # the module docstring only combines --push-to-hub with --slurm 0.
    from lerobot.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.utils.constants import HF_LEROBOT_HOME

    if args.output_root:
        output_root = Path(args.output_root)
    else:
        output_root = HF_LEROBOT_HOME / args.output_repo_id
    logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
    mirrored_dataset = LeRobotDataset(args.output_repo_id, root=output_root)
    mirrored_dataset.push_to_hub()


if __name__ == "__main__":
    main()
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -72,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
|
||||
# if features != meta.features:
|
||||
# raise ValueError(
|
||||
# f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
# )
|
||||
|
||||
return fps, robot_type, features
|
||||
|
||||
|
||||
@@ -563,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
@@ -1572,7 +1572,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
|
||||
@@ -205,7 +205,6 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,12 +48,21 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -64,7 +73,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,12 +40,24 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescales to a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_features` and `output_features`.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,12 +46,21 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes don't include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescales to a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -329,27 +329,26 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
The modified transition with the penalty added to complementary data.
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return new_transition
|
||||
return complementary_data
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -365,12 +364,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Update complementary data with penalty info
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_transition
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -412,10 +412,7 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
if kinematics_solver is not None:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -438,12 +435,7 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
|
||||
@@ -26,21 +26,8 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -49,8 +36,6 @@ def cfg_to_group(
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -98,7 +83,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -184,9 +184,6 @@ class DatasetRecordConfig:
|
||||
vcodec: str = "libsvtav1"
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
# Expert noise injection scale. Noise is added to robot actions but not recorded in dataset.
|
||||
# This forces recovery behavior for more robust learned policies. 0.0 means no noise. #https://arxiv.org/pdf/1703.09327, https://arxiv.org/abs/2507.09061
|
||||
noise_scale: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
@@ -286,7 +283,6 @@ def record_loop(
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
display_compressed_images: bool = False,
|
||||
noise_scale: float = 0.0,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
@@ -384,27 +380,18 @@ def record_loop(
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
# Write clean action to dataset (before noise injection)
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Expert noise injection: add noise to motor commands but not to recorded labels
|
||||
if noise_scale > 0:
|
||||
import torch
|
||||
|
||||
for key in robot_action_to_send:
|
||||
if isinstance(robot_action_to_send[key], torch.Tensor):
|
||||
noise = torch.randn_like(robot_action_to_send[key]) * noise_scale
|
||||
robot_action_to_send[key] = robot_action_to_send[key] + noise
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
@@ -523,7 +510,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed_images=display_compressed_images,
|
||||
noise_scale=cfg.dataset.noise_scale,
|
||||
)
|
||||
|
||||
# Execute a few seconds without recording to give time to manually reset the environment
|
||||
|
||||
@@ -337,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
# Filter out episodes - hardcoded list of bad episodes to discard
|
||||
episodes_to_discard = {
|
||||
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
|
||||
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
|
||||
3100, 3381, 3405, 3406, 68, 1214, 1456,
|
||||
}
|
||||
all_episodes = set(range(dataset.meta.total_episodes))
|
||||
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
|
||||
# If dataset.episodes is already filtered, start from that subset
|
||||
if episodes_to_use is not None:
|
||||
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
|
||||
else:
|
||||
episodes_to_use = sorted(all_episodes - episodes_to_discard)
|
||||
|
||||
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
|
||||
shuffle = False
|
||||
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
episode_indices_to_use=episodes_to_use,
|
||||
drop_n_last_frames=drop_n_last,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user