Add option to use the fused optim version of ADamW

This commit is contained in:
Michel Aractingi
2025-12-12 16:31:03 +00:00
parent ce348a3460
commit a669049da2
8 changed files with 23 additions and 5 deletions

View File

@@ -234,6 +234,8 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -257,6 +259,8 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
)
merged_dataset = LeRobotDataset(
@@ -747,11 +751,11 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of time ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
# Episodes are already in order by old episode index (from episode_mapping iteration),
# which equals from_timestamp order since episodes are created sequentially.
episodes_to_keep_ranges: list[tuple[float, float]] = []
for old_idx in sorted_keep_episodes:
for old_idx in episodes_in_file:
src_ep = src_dataset.meta.episodes[old_idx]
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
@@ -781,7 +785,7 @@ def _copy_and_reindex_videos(
)
cumulative_ts = 0.0
for old_idx in sorted_keep_episodes:
for old_idx in episodes_in_file:
new_idx = episode_mapping[old_idx]
src_ep = src_dataset.meta.episodes[old_idx]
ep_length = src_ep["length"]

View File

@@ -557,7 +557,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
tolerance_s: float = 1e-2,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,

View File

@@ -81,10 +81,14 @@ class AdamWConfig(OptimizerConfig):
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
fused: bool = False
def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
# Fused optimizer only works on CUDA
if kwargs.get("fused") and not torch.cuda.is_available():
kwargs["fused"] = False
return torch.optim.AdamW(params, **kwargs)

View File

@@ -136,6 +136,7 @@ class ACTConfig(PreTrainedConfig):
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
def __post_init__(self):
super().__post_init__()
@@ -164,6 +165,7 @@ class ACTConfig(PreTrainedConfig):
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self) -> None:

View File

@@ -94,6 +94,7 @@ class GrootConfig(PreTrainedConfig):
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
warmup_ratio: float = 0.05
use_bf16: bool = True
@@ -174,6 +175,7 @@ class GrootConfig(PreTrainedConfig):
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:

View File

@@ -74,6 +74,7 @@ class PI0Config(PreTrainedConfig):
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW``
@@ -141,6 +142,7 @@ class PI0Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):

View File

@@ -74,6 +74,7 @@ class PI05Config(PreTrainedConfig):
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW`
@@ -141,6 +142,7 @@ class PI05Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):

View File

@@ -79,6 +79,7 @@ class SmolVLAConfig(PreTrainedConfig):
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
optimizer_fused: bool = False
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
@@ -136,6 +137,7 @@ class SmolVLAConfig(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):