optimzize data loading

This commit is contained in:
Pepijn
2025-08-30 15:40:36 +02:00
parent 599218fe9a
commit 0b5da92a58
3 changed files with 15 additions and 6 deletions

View File

@@ -53,11 +53,7 @@ def make_rlearn_processor(
input_steps = [
# No renaming by default, but keep for future extensibility
RenameProcessor(rename_map={}),
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
# Move heavy normalization to GPU after transfer for better parallelism
ToBatchProcessor(),
RLearnLanguageFromTaskProcessor(),
# Use SigLIP2 for tokenizer to keep vocab aligned with text tower
@@ -69,6 +65,12 @@ def make_rlearn_processor(
padding_side="right",
),
DeviceProcessor(device=config.device),
# Move normalization after GPU transfer to use GPU acceleration
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps = [

View File

@@ -79,9 +79,10 @@ _ Open X-Embodiment (OXE)
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
- Test rewind (evaluate) [x]
- benchmark siglip 2 vs this implementation forward pass, debug speed [x]
- use siglip 2 [x]
- Overfit on one episode []
- Cleanup code? []
- benchmark siglip 2 vs this implementation forward pass, debug speed []
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
- Then on 10 percent
- Ablation dino v2 vs dino v3 base 86 M

View File

@@ -14,12 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
import torch
# Fix tokenizer parallelism conflicts with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
@@ -240,6 +244,8 @@ def train(cfg: TrainPipelineConfig):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
persistent_workers=cfg.num_workers > 0, # Keep workers alive
prefetch_factor=2 if cfg.num_workers > 0 else None, # Prefetch batches
)
dl_iter = cycle(dataloader)