diff --git a/docs/source/vla_jepa.mdx b/docs/source/vla_jepa.mdx index 6d960bd56..ad37b2349 100644 --- a/docs/source/vla_jepa.mdx +++ b/docs/source/vla_jepa.mdx @@ -74,6 +74,10 @@ Key parameters in `VLAJEPAConfig`: | `num_inference_timesteps` | 4 | Euler integration steps for action denoising | | `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head | | `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) | +| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) | +| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension | +| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper | +| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper | --- diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index bdd3fb0f5..8a30ee374 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -87,6 +87,8 @@ class VLAJEPAConfig(PreTrainedConfig): binarize_gripper_action: bool = True pre_snap_gripper_action: bool = True clip_normalized_actions: bool = True + gripper_dim: int = 6 + gripper_threshold: float = 0.5 torch_dtype: str = "bfloat16" optimizer_lr: float = 1e-4 diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py index d9b17be57..b59cc0e90 100644 --- a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -53,21 +53,25 @@ class ClipActionsProcessorStep(ProcessorStep): @ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper") class PreSnapGripperProcessorStep(ProcessorStep): - """Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization. + """Snaps a gripper dimension to {0, 1} BEFORE unnormalization. Mirrors the original starVLA LIBERO eval: - normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1) + normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1) This ensures the unnormalizer receives an exact binary value, which is required when the model was trained with gripper in identity (mask=False) space where 0=open and 1=close. """ + def __init__(self, gripper_dim: int = 6, threshold: float = 0.5): + self.gripper_dim = gripper_dim + self.threshold = threshold + def __call__(self, transition: EnvTransition) -> EnvTransition: action = transition.get(TransitionKey.ACTION) - if action is not None and action.shape[-1] >= 7: + if action is not None and action.shape[-1] > self.gripper_dim: transition = dict(transition) a = action.clone() - a[..., 6] = (a[..., 6] >= 0.5).float() + a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float() transition[TransitionKey.ACTION] = a return transition @@ -77,18 +81,22 @@ class PreSnapGripperProcessorStep(ProcessorStep): @ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper") class BinarizeGripperProcessorStep(ProcessorStep): - """Binarizes gripper dim (index 6) after unnormalization. + """Binarizes a gripper dimension after unnormalization. - Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention). - Only applied when action has >= 7 dimensions. + Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention). + Only applied when action has more dimensions than gripper_dim. """ + def __init__(self, gripper_dim: int = 6, threshold: float = 0.5): + self.gripper_dim = gripper_dim + self.threshold = threshold + def __call__(self, transition: EnvTransition) -> EnvTransition: action = transition.get(TransitionKey.ACTION) - if action is not None and action.shape[-1] >= 7: + if action is not None and action.shape[-1] > self.gripper_dim: transition = dict(transition) a = action.clone() - a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float() + a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float() transition[TransitionKey.ACTION] = a return transition @@ -118,7 +126,9 @@ def make_vla_jepa_pre_post_processors( if config.clip_normalized_actions: output_steps.append(ClipActionsProcessorStep()) if config.pre_snap_gripper_action: - output_steps.append(PreSnapGripperProcessorStep()) + output_steps.append( + PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold) + ) output_steps.append( UnnormalizerProcessorStep( features=features, @@ -127,7 +137,9 @@ def make_vla_jepa_pre_post_processors( ) ) if config.binarize_gripper_action: - output_steps.append(BinarizeGripperProcessorStep()) + output_steps.append( + BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold) + ) output_steps.append(DeviceProcessorStep(device="cpu")) return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](