adding configuration gripper index and threshold

This commit is contained in:
Maximellerbach
2026-05-29 14:11:51 +02:00
parent d8ed30d58c
commit 61d82a5773
3 changed files with 29 additions and 11 deletions

View File

@@ -74,6 +74,10 @@ Key parameters in `VLAJEPAConfig`:
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---

View File

@@ -87,6 +87,8 @@ class VLAJEPAConfig(PreTrainedConfig):
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4

View File

@@ -53,21 +53,25 @@ class ClipActionsProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., 6] = (a[..., 6] >= 0.5).float()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
@@ -77,18 +81,22 @@ class PreSnapGripperProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes gripper dim (index 6) after unnormalization.
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
Only applied when action has >= 7 dimensions.
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
@@ -118,7 +126,9 @@ def make_vla_jepa_pre_post_processors(
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(PreSnapGripperProcessorStep())
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
@@ -127,7 +137,9 @@ def make_vla_jepa_pre_post_processors(
)
)
if config.binarize_gripper_action:
output_steps.append(BinarizeGripperProcessorStep())
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](