feat(policies): Add X-VLA (#2405)

* first commit

* more fixes

* add franka action

* update testing script

* add changes

* update files

* logits matching

* add imagenet as a norm type

* logits matching atol1e-2

* more eval fixes

* more changes

* xvla works on libero

* remove seed

* more refactoring

* more fixes

* more changes

* more changes

* more fixes

* migrate policy revert

* major pre-commit cleanup

* renaming

* revert to self.transformer

* refactor

* new changes

* clean

* update libero

* more changes

* make it work

* more changes:

* remove imagenet dependency

* style

* more

* more refactor

* remove proprio

* add loss

* more

* more

* add freeze/unfreeze options

* add testing

* upgrade transformers version

* update testing

* add installation

* remove .sh file

* fix testing

* silent linter in xvlatest

* fix failing test

* upgrade test, fix failing

* fix testing

* more fixes to testing

* require cuda in tests

* temp check

* add xvla docs

* fix styling

* update libero doc

* remove timm dep

* add different dtype support

* remove timm skip

* remove white lines

* Enhance X-VLA finetuning documentation with optimizer details (#2537)

Added detailed instructions for implementing a custom optimizer and modifying parameter retrieval for X-VLA finetuning.

Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>

* fix style

* iterate on review

* iterate on cpilot

* revert xvla dep

* free up ci

* test(xvla): remove main test (#2565)

* Add xvla custom optim and dtype (#2567)

* add custom optim

* add custom optim

* add auto mode

* more changes

* add identity to all

* add auto

* release

* add docs

* make image smaller docs

* smaller image in doc

* evan smaller image doc

* finalize doc

---------

Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Jade Choghari
2025-12-03 15:29:14 +01:00
committed by GitHub
parent b0b755471b
commit 43b0f17eb9
22 changed files with 6620 additions and 10 deletions

View File

@@ -104,6 +104,107 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamWConfig(OptimizerConfig):
"""Custom AdamW optimizer for XVLA with differential learning rates.
The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
for stable optimization, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
Soft-prompts can optionally use a separate learning rate with warm-up support.
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
at a lower LR. Combine with a warmup scheduler for optimal results.
Note:
Completely matching official reported performance may require an additional
warm-up LR schedule for soft-prompts, which can bring minor improvements.
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
Parameter Groups:
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
- Group 2 (other): All other parameters at full lr
"""
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
# Soft-prompt specific settings
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Build AdamW optimizer with differential learning rates.
Expects `named_parameters()` as input (dict of name -> param).
Applies:
- lr * 0.1 for all VLM-related parameters
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
- full lr for all other parameters
Args:
params: Dictionary of parameter names to parameters (from named_parameters())
Returns:
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
"""
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
vlm_group, soft_prompt_group, other_group = [], [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
elif "soft_prompt" in name.lower():
soft_prompt_group.append(p)
else:
other_group.append(p)
# Determine soft-prompt LR
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
if self.soft_prompt_warmup_lr_scale is not None:
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
param_groups = [
{
"params": vlm_group,
"lr": self.lr * 0.1,
"weight_decay": self.weight_decay * 0.1,
"name": "vlm",
},
{
"params": soft_prompt_group,
"lr": soft_prompt_lr,
"weight_decay": self.weight_decay,
"name": "soft_prompts",
},
{
"params": other_group,
"lr": self.lr,
"weight_decay": self.weight_decay,
"name": "other",
},
]
# Filter out empty groups
param_groups = [g for g in param_groups if len(g["params"]) > 0]
return torch.optim.AdamW(
param_groups,
betas=self.betas,
eps=self.eps,
)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):