Compare commits

..

1 Commits

Author SHA1 Message Date
CarolinePascal
5231ba53dc fix(loading errors): improving dataset loading errors handling and logging 2026-04-09 17:43:42 +02:00
6 changed files with 14 additions and 11 deletions

View File

@@ -35,7 +35,7 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
streaming: bool = True
streaming: bool = False
def __post_init__(self) -> None:
if self.episodes is not None:

View File

@@ -39,7 +39,7 @@ class EvalPipelineConfig:
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
# Explicit consent to execute remote code from the Hub (required for hub environments).
trust_remote_code: bool = True
trust_remote_code: bool = False
def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -62,16 +62,16 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = True
use_amp: bool = False
# Whether the policy employed PEFT for training.
use_peft: bool = True
use_peft: bool = False
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
private: bool | None = True
private: bool | None = None
# Add tags to your policy on the hub.
tags: list[str] | None = None
# Add tags to your policy on the hub.

View File

@@ -46,13 +46,13 @@ class TrainPipelineConfig(HubMixin):
# `dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = True
resume: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
# Set to True to use deterministic cuDNN algorithms for reproducibility.
# This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
cudnn_deterministic: bool = True
cudnn_deterministic: bool = False
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
@@ -60,10 +60,10 @@ class TrainPipelineConfig(HubMixin):
eval_freq: int = 20_000
log_freq: int = 200
tolerance_s: float = 1e-4
save_checkpoint: bool = False
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
use_policy_training_preset: bool = False
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)

View File

@@ -87,7 +87,7 @@ class DatasetReader:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
self.hf_dataset = self._load_hf_dataset()
except (FileNotFoundError, NotADirectoryError):
except (FileNotFoundError, NotADirectoryError, ValueError):
self.hf_dataset = None
return False
if not self._check_cached_episodes_sufficient():

View File

@@ -78,7 +78,10 @@ def load_nested_dataset(
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
try:
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
except ValueError:
raise ValueError(f"Failed to load parquet files in {pq_dir}, make sure the dataset is valid and is not missing any files.")
def get_parquet_num_frames(parquet_path: str | Path) -> int: