Optimize dataset updates by incrementally concatenating new data instead of reloading from disk, reducing memory usage and improving performance.

This commit is contained in:
Michel Aractingi
2025-09-05 18:37:48 +02:00
parent 992fb177c3
commit 0747afdba7

View File

@@ -28,7 +28,7 @@ import pandas as pd
import PIL.Image
import torch
import torch.utils
from datasets import Dataset
from datasets import Dataset, concatenate_datasets
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
@@ -316,17 +316,16 @@ class LeRobotDatasetMetadata:
path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(path, index=False)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
# Update the Hugging Face dataset incrementally instead of reloading it from disk.
# This avoids the repeated load_episodes calls that cause cache bloat.
if self.episodes is None:
self.episodes = load_episodes(self.root)
return
# Explicitly delete old dataset to free memory before reloading
if hasattr(self, "episodes") and self.episodes is not None:
del self.episodes
self.episodes = None
gc.collect()
self.episodes = load_episodes(self.root)
# Drop every column whose name starts with 'stats/' before converting df to a Dataset.
df = df.drop(columns=[col for col in df.columns if col.startswith("stats/")])
new_episode_dataset = Dataset.from_pandas(df)
self.episodes = concatenate_datasets([self.episodes, new_episode_dataset])
def save_episode(
self,
@@ -1064,17 +1063,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
df.to_parquet(path)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
# Explicitly delete old dataset to free memory before reloading
if hasattr(self, "hf_dataset") and self.hf_dataset is not None:
del self.hf_dataset
self.hf_dataset = None
gc.collect()
self.hf_dataset = self.load_hf_dataset()
new_hf_dataset = Dataset.from_pandas(df)
self.hf_dataset = concatenate_datasets([self.hf_dataset, new_hf_dataset])
metadata = {
"data/chunk_index": chunk_idx,
@@ -1093,7 +1083,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.meta.episodes is None:
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
latest_duration_in_s = 0
latest_duration_in_s = 0.0
new_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
@@ -1119,6 +1109,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
# Update latest video file
concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)