mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
Compare commits
16 Commits
exp/video-
...
refactor/l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4ff7dbf52 | ||
|
|
9e10eb4a77 | ||
|
|
2a6e6bef2b | ||
|
|
bbd1f1f920 | ||
|
|
5872b04851 | ||
|
|
80b0f1aaa2 | ||
|
|
0264ac717b | ||
|
|
94efcea867 | ||
|
|
faa276b8cf | ||
|
|
fd7cb9a5c5 | ||
|
|
25e3119a33 | ||
|
|
f49b8537ad | ||
|
|
ded91ca866 | ||
|
|
9ebc144b30 | ||
|
|
ba690632d9 | ||
|
|
eb2e79e22d |
@@ -99,6 +99,8 @@
|
||||
title: Unitree G1
|
||||
- local: earthrover_mini_plus
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
197
docs/source/omx.mdx
Normal file
197
docs/source/omx.mdx
Normal file
@@ -0,0 +1,197 @@
|
||||
## Order and Assemble the parts
|
||||
|
||||
First, assemble the OMX hardware following the official assembly guide.
|
||||
|
||||
OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html
|
||||
|
||||
OMX robots are shipped preconfigured from the factory. Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
To install LeRobot, follow our [Installation Guide](./installation)
|
||||
|
||||
In addition to these instructions, you need to install the Dynamixel SDK:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
## Connect the robot
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports.
|
||||
|
||||
<hfoptions id="find_port">
|
||||
<hfoption id="Mac">
|
||||
|
||||
Example output on macOS:
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the USB cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect corresponding leader or follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the USB cable.
|
||||
```
|
||||
|
||||
Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted.
|
||||
|
||||
#### 1. Find your device serial numbers
|
||||
|
||||
You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command.
|
||||
To create udev rules, you need the unique serial number for each OMX device. The easiest way is to list devices under:
|
||||
|
||||
```bash
|
||||
ls -l /dev/serial/by-id/
|
||||
```
|
||||
|
||||
You will see output similar to:
|
||||
|
||||
```bash
|
||||
usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0
|
||||
usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1
|
||||
```
|
||||
|
||||
In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`.
|
||||
|
||||
Follower serial: `228BDD7B503059384C2E3120FF0A2B19`
|
||||
|
||||
Leader serial: `67E1ED68503059384C2E3120FF092234`
|
||||
|
||||
#### 2. Create the udev rule
|
||||
|
||||
Create a new udev rule file:
|
||||
|
||||
```bash
|
||||
sudo nano /etc/udev/rules.d/99-omx.rules
|
||||
```
|
||||
|
||||
Paste the following lines, replacing the serial numbers with the values you found above:
|
||||
|
||||
```bash
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower"
|
||||
SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader"
|
||||
```
|
||||
|
||||
Save the file and reload udev rules:
|
||||
|
||||
```bash
|
||||
sudo udevadm control --reload-rules
|
||||
sudo udevadm trigger
|
||||
```
|
||||
|
||||
Now unplug and replug both devices once.
|
||||
|
||||
#### 3. Verify the symlinks
|
||||
|
||||
Check that the persistent device names exist:
|
||||
|
||||
```bash
|
||||
ls -l /dev/omx_follower /dev/omx_leader
|
||||
```
|
||||
|
||||
You should see them pointing to ttyACM\* devices:
|
||||
|
||||
```bash
|
||||
/dev/omx_follower -> ttyACM*
|
||||
/dev/omx_leader -> ttyACM*
|
||||
```
|
||||
|
||||
These names remain stable across reboots and reconnections.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperate
|
||||
|
||||
After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Mac">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=<your_follower_port> \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=<your_leader_port> \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
### Teleoperate without camera
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm
|
||||
```
|
||||
|
||||
During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. OMX is already preconfigured, teleoperation can begin immediately without any calibration steps.
|
||||
|
||||
### Teleoperate with camera
|
||||
|
||||
You can also enable camera input during teleoperation by providing a camera configuration for the follower arm.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=/dev/omx_follower \
|
||||
--robot.id=omx_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=/dev/omx_leader \
|
||||
--teleop.id=omx_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own.
|
||||
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis).
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,12 +488,13 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
--sample_weighting.kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
--sample_weighting.kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
--sample_weighting.kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -550,8 +551,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -576,7 +577,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -105,16 +105,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
"""Return indices for delta image observations only.
|
||||
|
||||
Unlike observation_delta_indices which applies to ALL observations,
|
||||
this only applies to image observations (keys starting with observation.images).
|
||||
Default returns None. Override in subclass to enable.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
@@ -67,12 +68,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
@@ -140,14 +137,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -59,12 +59,7 @@ def resolve_delta_timestamps(
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
|
||||
# Check for image-specific delta indices first (e.g., for video encoding)
|
||||
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
|
||||
# Fall back to generic observation delta indices for all observations
|
||||
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -35,7 +35,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
@@ -68,7 +67,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi05_video":
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
|
||||
return PI05VideoPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -152,7 +147,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -174,8 +169,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi05_video":
|
||||
return PI05VideoConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -340,14 +333,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05VideoConfig):
|
||||
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
|
||||
|
||||
processors = make_pi05_video_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
|
||||
@@ -460,8 +460,8 @@ class PaliGemmaWithExpertModel(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=False,
|
||||
past_key_values=None, #jadechoghari
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
@@ -575,13 +575,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
# try:
|
||||
# from transformers.models.siglip import check
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
# if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
# raise ValueError(msg)
|
||||
# except ImportError:
|
||||
# raise ValueError(msg) from None
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,6 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
|
||||
|
||||
This module implements the SampleWeighter protocol for RA-BC training,
|
||||
which weights training samples based on their task progress as measured
|
||||
by the SARM reward model.
|
||||
|
||||
The weights are computed based on progress deltas:
|
||||
delta = progress[t + chunk_size] - progress[t]
|
||||
|
||||
High-quality samples (positive progress) get higher weights, while
|
||||
samples with negative progress (going backwards) get zero weight.
|
||||
|
||||
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
@@ -22,6 +36,8 @@ import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.utils.sample_weighting import SampleWeighter
|
||||
|
||||
|
||||
def resolve_hf_path(path: str | Path) -> Path:
|
||||
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
|
||||
@@ -34,23 +50,27 @@ def resolve_hf_path(path: str | Path) -> Path:
|
||||
return Path(path)
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
class RABCWeights(SampleWeighter):
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
This class implements the SampleWeighter ABC for use with the generic
|
||||
sample weighting infrastructure in lerobot.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
During training, computes:
|
||||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||
|
||||
Args:
|
||||
progress_path: Path to parquet file with precomputed progress values
|
||||
chunk_size: Number of frames ahead for computing progress delta
|
||||
head_mode: Which SARM head to use ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||||
device: Device to return tensors on
|
||||
progress_path: Path to parquet file with precomputed progress values.
|
||||
Supports HuggingFace URLs (hf://datasets/...).
|
||||
chunk_size: Number of frames ahead for computing progress delta.
|
||||
head_mode: Which SARM head to use ("sparse" or "dense").
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01).
|
||||
epsilon: Small constant for numerical stability (default: 1e-6).
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
|
||||
device: Device to return tensors on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -61,7 +81,7 @@ class RABCWeights:
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
fallback_weight: float = 1.0,
|
||||
device: torch.device = None,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
self.progress_path = resolve_hf_path(progress_path)
|
||||
self.chunk_size = chunk_size
|
||||
@@ -87,8 +107,8 @@ class RABCWeights:
|
||||
|
||||
logging.info(f"Using progress column: {self.progress_column}")
|
||||
|
||||
self.progress_lookup = {}
|
||||
self.episode_lookup = {}
|
||||
self.progress_lookup: dict[int, float] = {}
|
||||
self.episode_lookup: dict[int, int] = {}
|
||||
|
||||
for _, row in self.df.iterrows():
|
||||
global_idx = int(row["index"])
|
||||
@@ -100,7 +120,7 @@ class RABCWeights:
|
||||
self.episode_lookup[global_idx] = episode_idx
|
||||
|
||||
# Build episode boundaries for delta computation
|
||||
self.episode_boundaries = {}
|
||||
self.episode_boundaries: dict[int, dict[str, int]] = {}
|
||||
for episode_idx in self.df["episode_index"].unique():
|
||||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||
self.episode_boundaries[int(episode_idx)] = {
|
||||
@@ -114,7 +134,7 @@ class RABCWeights:
|
||||
# Compute global statistics for weight computation
|
||||
self._compute_global_stats()
|
||||
|
||||
def _compute_global_stats(self):
|
||||
def _compute_global_stats(self) -> None:
|
||||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||
all_deltas = []
|
||||
|
||||
@@ -138,8 +158,8 @@ class RABCWeights:
|
||||
all_deltas.append(delta)
|
||||
|
||||
if all_deltas:
|
||||
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||||
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||||
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
|
||||
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
|
||||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||
else:
|
||||
self.delta_mean = 0.0
|
||||
@@ -157,18 +177,19 @@ class RABCWeights:
|
||||
4. Compute weight using paper Eq. 8-9
|
||||
|
||||
Args:
|
||||
batch: Training batch containing "index" key with global frame indices
|
||||
batch: Training batch containing "index" key with global frame indices.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||||
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size.
|
||||
- Stats dict with weighting statistics for logging.
|
||||
"""
|
||||
indices = batch.get("index")
|
||||
if indices is None:
|
||||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||
batch_size = self._get_batch_size(batch)
|
||||
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||||
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
|
||||
return torch.ones(batch_size, device=self.device), stats
|
||||
|
||||
# Convert to list of ints
|
||||
if isinstance(indices, torch.Tensor):
|
||||
@@ -183,29 +204,29 @@ class RABCWeights:
|
||||
delta = self._compute_delta(idx)
|
||||
deltas.append(delta)
|
||||
|
||||
deltas = np.array(deltas, dtype=np.float32)
|
||||
deltas_array = np.array(deltas, dtype=np.float32)
|
||||
|
||||
# Compute weights from deltas
|
||||
weights = self._compute_weights(deltas)
|
||||
weights = self._compute_weights(deltas_array)
|
||||
|
||||
# Compute stats before normalization for logging
|
||||
raw_mean_weight = float(np.nanmean(weights))
|
||||
num_zero_weight = int(np.sum(weights == 0))
|
||||
num_full_weight = int(np.sum(weights == 1.0))
|
||||
batch_stats = {
|
||||
"raw_mean_weight": raw_mean_weight,
|
||||
"mean_weight": raw_mean_weight,
|
||||
"num_zero_weight": num_zero_weight,
|
||||
"num_full_weight": num_full_weight,
|
||||
}
|
||||
|
||||
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Normalize to sum to batch_size
|
||||
batch_size = len(weights)
|
||||
weight_sum = weights.sum() + self.epsilon
|
||||
weights = weights * batch_size / weight_sum
|
||||
batch_size = len(weights_tensor)
|
||||
weight_sum = weights_tensor.sum() + self.epsilon
|
||||
weights_tensor = weights_tensor * batch_size / weight_sum
|
||||
|
||||
return weights, batch_stats
|
||||
return weights_tensor, batch_stats
|
||||
|
||||
def _compute_delta(self, global_idx: int) -> float:
|
||||
"""Compute progress delta for a single frame."""
|
||||
@@ -241,7 +262,7 @@ class RABCWeights:
|
||||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||||
|
||||
Returns:
|
||||
Array of weights
|
||||
Array of weights.
|
||||
"""
|
||||
valid_mask = ~np.isnan(deltas)
|
||||
|
||||
@@ -273,12 +294,13 @@ class RABCWeights:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
return val.shape[0]
|
||||
return int(val.shape[0])
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics."""
|
||||
"""Get global statistics about the RA-BC weighting."""
|
||||
return {
|
||||
"type": "rabc",
|
||||
"num_frames": len(self.progress_lookup),
|
||||
"chunk_size": self.chunk_size,
|
||||
"head_mode": self.head_mode,
|
||||
@@ -1,49 +0,0 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
|
||||
# when only importing subpackages like videoprism
|
||||
def __getattr__(name):
|
||||
if name == "PI05VideoConfig":
|
||||
from .configuration_pi05 import PI05VideoConfig
|
||||
return PI05VideoConfig
|
||||
elif name == "PI05VideoPolicy":
|
||||
from .modeling_pi05 import PI05VideoPolicy
|
||||
return PI05VideoPolicy
|
||||
elif name == "make_pi05_video_pre_post_processors":
|
||||
from .processor_pi05 import make_pi05_video_pre_post_processors
|
||||
return make_pi05_video_pre_post_processors
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"]
|
||||
@@ -1,212 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05_video")
|
||||
@dataclass
|
||||
class PI05VideoConfig(PreTrainedConfig):
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
|
||||
# Video encoder settings (VideoPrism)
|
||||
use_video_encoder: bool = False # Enable video encoding with VideoPrism
|
||||
video_num_frames: int = 16 # Number of frames for video encoding (VideoPrism default is 16)
|
||||
videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288" # VideoPrism model to use
|
||||
videoprism_image_size: int = 288 # VideoPrism expects 288x288 images
|
||||
freeze_video_encoder: bool = True # Whether to freeze the video encoder weights
|
||||
video_padding_mode: str = "repeat" # How to pad frames at episode start: "repeat" or "zero"
|
||||
# Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top")
|
||||
video_encoder_camera_key: str | None = None
|
||||
# Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents)
|
||||
video_num_latents: int = 256 # Number of latent tokens for video resampler
|
||||
video_resampler_num_heads: int = 8 # Number of attention heads in resampler
|
||||
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = False # Freeze only the vision encoder
|
||||
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
# Validate video encoder settings
|
||||
if self.use_video_encoder:
|
||||
if self.video_num_frames < 1:
|
||||
raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}")
|
||||
if self.videoprism_image_size < 1:
|
||||
raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}")
|
||||
if self.video_padding_mode not in ["repeat", "zero"]:
|
||||
raise ValueError(
|
||||
f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution), # Use configured image resolution
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,), # Padded to max_state_dim
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,), # Padded to max_action_dim
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
"""Return indices for delta observations.
|
||||
|
||||
For PI05, we don't use generic observation_delta_indices because it would
|
||||
apply to both images AND state. Instead, we use image_observation_delta_indices
|
||||
which only applies to image observations.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def image_observation_delta_indices(self) -> list[int] | None:
|
||||
"""Return indices for delta image observations only.
|
||||
|
||||
When video encoding is enabled, returns indices for the past frames
|
||||
needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames).
|
||||
This only applies to image observations, not state.
|
||||
"""
|
||||
if self.use_video_encoder:
|
||||
# Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames
|
||||
return list(range(-(self.video_num_frames - 1), 1))
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
|
||||
if state is None:
|
||||
raise ValueError("State is required for PI05")
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
state = pad_vector(state, self.max_state_dim)
|
||||
|
||||
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi05_video_pre_post_processors(
|
||||
config: PI05VideoConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI05Video policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,214 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for PI05 with video encoder (VideoPrism).
|
||||
|
||||
This script creates a dummy example to test the model with video encoding enabled.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def create_dummy_batch(
|
||||
batch_size: int = 2,
|
||||
num_frames: int = 16,
|
||||
image_size: int = 224,
|
||||
num_cameras: int = 2,
|
||||
state_dim: int = 14,
|
||||
action_dim: int = 14,
|
||||
chunk_size: int = 50,
|
||||
seq_len: int = 10,
|
||||
device: str = "cuda",
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Create a dummy batch for testing."""
|
||||
batch = {}
|
||||
|
||||
# Create image observations with temporal dimension [B, T, C, H, W]
|
||||
for i in range(num_cameras):
|
||||
key = f"{OBS_IMAGES}.camera_{i}"
|
||||
# Images in [0, 1] range
|
||||
batch[key] = torch.rand(batch_size, num_frames, 3, image_size, image_size, device=device)
|
||||
|
||||
# Create state observation [B, state_dim]
|
||||
batch[OBS_STATE] = torch.rand(batch_size, state_dim, device=device)
|
||||
|
||||
# Create language tokens and attention mask [B, seq_len]
|
||||
batch["observation.language.tokens"] = torch.randint(0, 1000, (batch_size, seq_len), device=device)
|
||||
batch["observation.language.attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Create action targets [B, chunk_size, action_dim]
|
||||
batch[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def test_video_encoder():
|
||||
"""Test the PI05 model with video encoding enabled."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Configuration
|
||||
batch_size = 2
|
||||
num_frames = 16
|
||||
image_size = 224
|
||||
num_cameras = 2
|
||||
state_dim = 14
|
||||
action_dim = 14
|
||||
chunk_size = 50
|
||||
|
||||
# Create config with video encoder enabled
|
||||
print("Creating PI05VideoConfig with video encoder...")
|
||||
config = PI05VideoConfig(
|
||||
use_video_encoder=True,
|
||||
video_num_frames=num_frames,
|
||||
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
|
||||
videoprism_image_size=288,
|
||||
freeze_video_encoder=True,
|
||||
video_padding_mode="repeat",
|
||||
video_encoder_camera_key=f"{OBS_IMAGES}.camera_0", # Use first camera for video
|
||||
chunk_size=chunk_size,
|
||||
max_action_dim=32,
|
||||
max_state_dim=32,
|
||||
dtype="float32", # Use float32 for testing
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set up input/output features
|
||||
for i in range(num_cameras):
|
||||
key = f"{OBS_IMAGES}.camera_{i}"
|
||||
config.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, image_size, image_size),
|
||||
)
|
||||
|
||||
config.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(state_dim,),
|
||||
)
|
||||
|
||||
config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(action_dim,),
|
||||
)
|
||||
|
||||
print(f"use_video_encoder: {config.use_video_encoder}")
|
||||
print(f"video_num_frames: {config.video_num_frames}")
|
||||
print(f"video_padding_mode: {config.video_padding_mode}")
|
||||
print(f"video_encoder_camera_key: {config.video_encoder_camera_key}")
|
||||
print(f"image_observation_delta_indices: {config.image_observation_delta_indices}")
|
||||
|
||||
# Create model
|
||||
model = PI05VideoPolicy(config)
|
||||
model.to(device)
|
||||
|
||||
# Create dummy batch
|
||||
batch = create_dummy_batch(
|
||||
batch_size=batch_size,
|
||||
num_frames=num_frames,
|
||||
image_size=image_size,
|
||||
num_cameras=num_cameras,
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
chunk_size=chunk_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
print(f"Batch keys: {list(batch.keys())}" )
|
||||
for key, value in batch.items():
|
||||
print(f"{key}: {value.shape}")
|
||||
|
||||
# Test forward pass
|
||||
model.train()
|
||||
try:
|
||||
loss, loss_dict = model.forward(batch)
|
||||
print(f"Forward pass successful!")
|
||||
print(f"Loss: {loss.item():.4f}")
|
||||
print(f"Loss dict: {loss_dict}")
|
||||
except Exception as e:
|
||||
print(f"Forward pass failed: {e}")
|
||||
raise
|
||||
|
||||
# Test inference
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
try:
|
||||
actions = model.predict_action_chunk(batch)
|
||||
print(f"Test pass, inference pass!")
|
||||
print(f"Predicted actions shape: {actions.shape}")
|
||||
except Exception as e:
|
||||
print(f"Inference failed: {e}")
|
||||
raise
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
|
||||
def test_frame_padding():
|
||||
"""Test frame padding at episode start."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Create config
|
||||
config = PI05VideoConfig(
|
||||
use_video_encoder=True,
|
||||
video_num_frames=16,
|
||||
videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
|
||||
freeze_video_encoder=True,
|
||||
video_padding_mode="repeat",
|
||||
chunk_size=50,
|
||||
dtype="float32",
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set up minimal features
|
||||
config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
)
|
||||
config.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(14,),
|
||||
)
|
||||
|
||||
# Create model
|
||||
model = PI05VideoPolicy(config)
|
||||
model.to(device)
|
||||
|
||||
# Test with fewer frames than expected (simulating episode start)
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device),
|
||||
"observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device),
|
||||
"observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device),
|
||||
ACTION: torch.rand(2, 50, 14, device=device),
|
||||
}
|
||||
|
||||
video_frames = model._preprocess_video(batch)
|
||||
if video_frames is not None:
|
||||
print(f"Input frames: 5")
|
||||
print(f"Output video_frames shape: {video_frames.shape}")
|
||||
print(f"Expected: [2, 16, 3, 224, 224]")
|
||||
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
|
||||
print("Frame padding test PASSED!")
|
||||
else:
|
||||
print("video_frames is None (unexpected)")
|
||||
|
||||
# Test with single frame
|
||||
batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device) # [B, C, H, W]
|
||||
|
||||
video_frames = model._preprocess_video(batch)
|
||||
if video_frames is not None:
|
||||
print(f"Input: single frame [B, C, H, W]")
|
||||
print(f"Output video_frames shape: {video_frames.shape}")
|
||||
print(f"Expected: [2, 16, 3, 224, 224]")
|
||||
assert video_frames.shape == (2, 16, 3, 224, 224), f"Unexpected shape: {video_frames.shape}"
|
||||
print("Single frame expansion test PASSED!")
|
||||
else:
|
||||
print("video_frames is None (unexpected)")
|
||||
|
||||
print("All tests passed!")
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
test_frame_padding()
|
||||
test_video_encoder()
|
||||
@@ -1,37 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
|
||||
from .modeling_videoprism import (
|
||||
VideoPrismClipModel,
|
||||
VideoPrismForVideoClassification,
|
||||
VideoPrismPreTrainedModel,
|
||||
VideoPrismTextModel,
|
||||
VideoPrismVideoModel,
|
||||
VideoPrismVisionModel,
|
||||
)
|
||||
from .video_processing_videoprism import VideoPrismVideoProcessor
|
||||
|
||||
__all__ = [
|
||||
"VideoPrismConfig",
|
||||
"VideoPrismTextConfig",
|
||||
"VideoPrismVisionConfig",
|
||||
"VideoPrismClipModel",
|
||||
"VideoPrismForVideoClassification",
|
||||
"VideoPrismPreTrainedModel",
|
||||
"VideoPrismTextModel",
|
||||
"VideoPrismVideoModel",
|
||||
"VideoPrismVisionModel",
|
||||
"VideoPrismVideoProcessor",
|
||||
]
|
||||
@@ -1,269 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VideoPrismVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`VideoPrismVisionModel`]. It is used to instantiate a
|
||||
VideoPrism vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the VideoPrism
|
||||
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
image_size (`int`, *optional*, defaults to 288):
|
||||
The size of the input image.
|
||||
num_frames (`int`, *optional*, defaults to 16):
|
||||
The number of frames in the input video.
|
||||
tubelet_size (`List[int]`, *optional*, defaults to `[1, 18, 18]`):
|
||||
The size of the tubelet patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_spatial_layers (`int`, *optional*, defaults to 12):
|
||||
Number of spatial transformer blocks.
|
||||
num_temporal_layers (`int`, *optional*, defaults to 4):
|
||||
Number of temporal transformer blocks.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
|
||||
The non-linear activation function (function or string).
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the qkv projections in attention layers.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||
Softcapping constant for attention logits.
|
||||
num_auxiliary_layers (`int`, *optional*, defaults to 2):
|
||||
Number of auxiliary layers. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
|
||||
apply_l2_norm (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply L2 normalization to the output. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismVisionConfig, VideoPrismVisionModel
|
||||
|
||||
>>> # Initializing a VideoPrismVisionConfig with default values
|
||||
>>> configuration = VideoPrismVisionConfig()
|
||||
|
||||
>>> # Initializing a VideoPrismVisionModel with the configuration
|
||||
>>> model = VideoPrismVisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "videoprism_vision_model"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size=288,
|
||||
num_frames=16,
|
||||
tubelet_size=[1, 18, 18],
|
||||
num_channels=3,
|
||||
hidden_size=768,
|
||||
num_spatial_layers=12,
|
||||
num_temporal_layers=4,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu_python",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-06,
|
||||
qkv_bias=True,
|
||||
attn_logit_softcapping=50.0,
|
||||
num_auxiliary_layers=2,
|
||||
apply_l2_norm=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.image_size = image_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.num_channels = num_channels
|
||||
self.qkv_bias = qkv_bias
|
||||
self.num_spatial_layers = num_spatial_layers
|
||||
self.num_temporal_layers = num_temporal_layers
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
self.num_auxiliary_layers = num_auxiliary_layers
|
||||
self.apply_l2_norm = apply_l2_norm
|
||||
|
||||
|
||||
class VideoPrismTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`VideoPrismTextModel`]. It is used to instantiate a
|
||||
VideoPrism text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the VideoPrism
|
||||
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_text_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the text Transformer encoder.
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the text model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`VideoPrismTextModel`].
|
||||
apply_l2_norm (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply L2 normalization to the output text embeddings.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the query, key, and value projections in the attention layers.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||
Softcapping constant for attention logits.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismTextConfig, VideoPrismTextModel
|
||||
|
||||
>>> # Initializing a VideoPrismTextConfig with default values
|
||||
>>> configuration = VideoPrismTextConfig()
|
||||
|
||||
>>> # Initializing a VideoPrismTextModel (with random weights) from the configuration
|
||||
>>> model = VideoPrismTextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "videoprism_text_model"
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_attention_heads=12,
|
||||
num_text_layers=12,
|
||||
vocab_size=32000,
|
||||
apply_l2_norm=True,
|
||||
hidden_act="relu",
|
||||
attention_probs_dropout_prob=0.0,
|
||||
qkv_bias=True,
|
||||
hidden_dropout_prob=0.0,
|
||||
layer_norm_eps=1e-06,
|
||||
initializer_range=0.02,
|
||||
attn_logit_softcapping=50.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_text_layers = num_text_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.apply_l2_norm = apply_l2_norm
|
||||
self.hidden_act = hidden_act
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.qkv_bias = qkv_bias
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
|
||||
class VideoPrismConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`VideoPrismModel`]. It is used to instantiate a
|
||||
VideoPrism model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the VideoPrism
|
||||
[google/videoprism](https://huggingface.co/google/videoprism) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`VideoPrismTextConfig`, *optional*):
|
||||
Configuration for the text model.
|
||||
vision_config (`VideoPrismVisionConfig`, *optional*):
|
||||
Configuration for the vision model.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismConfig, VideoPrismModel
|
||||
|
||||
>>> # Initializing a VideoPrismConfig with default values
|
||||
>>> configuration = VideoPrismConfig()
|
||||
|
||||
>>> # Initializing a VideoPrismClipModel with the configuration
|
||||
>>> model = VideoPrismClipModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "videoprism"
|
||||
sub_configs = {"text_config": VideoPrismTextConfig, "vision_config": VideoPrismVisionConfig}
|
||||
|
||||
def __init__(self, text_config=None, vision_config=None, **kwargs):
|
||||
if text_config is None:
|
||||
text_config = VideoPrismTextConfig()
|
||||
logger.info("`text_config` is `None`. Initializing the `VideoPrismTextConfig` with default values.")
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = VideoPrismTextConfig(**text_config)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = VideoPrismVisionConfig()
|
||||
logger.info("`vision_config` is `None`. initializing the `VideoPrismVisionConfig` with default values.")
|
||||
elif isinstance(vision_config, dict):
|
||||
vision_config = VideoPrismVisionConfig(**vision_config)
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["VideoPrismVisionConfig", "VideoPrismTextConfig", "VideoPrismConfig"]
|
||||
@@ -1,245 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
|
||||
# in context managers
|
||||
TORCH_INIT_FUNCTIONS = {
|
||||
"uniform_": torch.nn.init.uniform_,
|
||||
"normal_": torch.nn.init.normal_,
|
||||
"constant_": torch.nn.init.constant_,
|
||||
"ones_": torch.nn.init.ones_,
|
||||
"zeros_": torch.nn.init.zeros_,
|
||||
"eye_": torch.nn.init.eye_,
|
||||
"dirac_": torch.nn.init.dirac_,
|
||||
"xavier_uniform_": torch.nn.init.xavier_uniform_,
|
||||
"xavier_normal_": torch.nn.init.xavier_normal_,
|
||||
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
|
||||
"kaiming_normal_": torch.nn.init.kaiming_normal_,
|
||||
"trunc_normal_": torch.nn.init.trunc_normal_,
|
||||
"orthogonal_": torch.nn.init.orthogonal_,
|
||||
"sparse_": torch.nn.init.sparse_,
|
||||
}
|
||||
|
||||
|
||||
def uniform_(
|
||||
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def normal_(
|
||||
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
|
||||
return tensor
|
||||
|
||||
|
||||
def ones_(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["ones_"](tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
def eye_(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["eye_"](tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
|
||||
return tensor
|
||||
|
||||
|
||||
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def kaiming_uniform_(
|
||||
tensor: torch.Tensor,
|
||||
a: float = 0,
|
||||
mode: str = "fan_in",
|
||||
nonlinearity: str = "leaky_relu",
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
|
||||
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
def kaiming_normal_(
|
||||
tensor: torch.Tensor,
|
||||
a: float = 0,
|
||||
mode: str = "fan_in",
|
||||
nonlinearity: str = "leaky_relu",
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
|
||||
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(
|
||||
tensor: torch.Tensor,
|
||||
mean: float = 0.0,
|
||||
std: float = 1.0,
|
||||
a: float = -2.0,
|
||||
b: float = 2.0,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def orthogonal_(
|
||||
tensor: torch.Tensor,
|
||||
gain: float = 1,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def sparse_(
|
||||
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
with torch.no_grad():
|
||||
return tensor.copy_(other)
|
||||
return tensor
|
||||
|
||||
|
||||
# Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
|
||||
# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations,
|
||||
# where MultiHeadAttention lives), so the function name is binded at import time and just doing
|
||||
# `setattr(torch.nn.init, name, globals()[name])` is thus not enough
|
||||
# The following list should be enough for all torch versions we work with
|
||||
TORCH_MODULES_TO_PATCH = (
|
||||
"torch.nn.init",
|
||||
"torch.nn.modules.activation",
|
||||
"torch.nn.modules.transformer",
|
||||
"torch.nn.modules.linear",
|
||||
"torch.nn.modules.loss",
|
||||
"torch.nn.modules.batchnorm",
|
||||
"torch.nn.modules.conv",
|
||||
"torch.nn.modules.normalization",
|
||||
"torch.nn.modules.rnn",
|
||||
"torch.nn.modules.sparse",
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def guard_torch_init_functions():
|
||||
"""
|
||||
Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
|
||||
protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.
|
||||
|
||||
Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
|
||||
and for remote code, we also use this context manager.
|
||||
"""
|
||||
originals = defaultdict(dict)
|
||||
try:
|
||||
# Replace all torch funcs by the ones in this file
|
||||
for module_name in TORCH_MODULES_TO_PATCH:
|
||||
if module_name in sys.modules:
|
||||
module = sys.modules[module_name]
|
||||
for func_name in TORCH_INIT_FUNCTIONS.keys():
|
||||
if hasattr(module, func_name):
|
||||
originals[module][func_name] = getattr(module, func_name)
|
||||
setattr(module, func_name, globals()[func_name])
|
||||
yield
|
||||
finally:
|
||||
# Set back the original functions on all modules
|
||||
for module, functions in originals.items():
|
||||
for func_name, func in functions.items():
|
||||
setattr(module, func_name, func)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_init_weights():
|
||||
"""
|
||||
Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
|
||||
This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
|
||||
with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.
|
||||
"""
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
def empty_func(*args, **kwargs):
|
||||
pass
|
||||
|
||||
originals = defaultdict(dict)
|
||||
try:
|
||||
# Replace all torch funcs by empty ones
|
||||
for module_name in TORCH_MODULES_TO_PATCH:
|
||||
if module_name in sys.modules:
|
||||
module = sys.modules[module_name]
|
||||
for func_name in TORCH_INIT_FUNCTIONS.keys():
|
||||
if hasattr(module, func_name):
|
||||
originals[module][func_name] = getattr(module, func_name)
|
||||
setattr(module, func_name, empty_func)
|
||||
|
||||
# Also patch our own `init_weights`
|
||||
original_init_weights = PreTrainedModel.init_weights
|
||||
PreTrainedModel.init_weights = empty_func
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Set back the original torch functions on all modules
|
||||
for module, functions in originals.items():
|
||||
for func_name, func in functions.items():
|
||||
setattr(module, func_name, func)
|
||||
# Set back `init_weights`
|
||||
PreTrainedModel.init_weights = original_init_weights
|
||||
@@ -1,994 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from . import initialization as init
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
|
||||
|
||||
def torch_int(x):
|
||||
"""
|
||||
Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
|
||||
"""
|
||||
if not torch.is_available():
|
||||
return int(x)
|
||||
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithSpatialAndTemporalStates(ModelOutput):
|
||||
"""
|
||||
Base class for model outputs that include spatial and temporal states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (Optional[torch.FloatTensor]):
|
||||
The last hidden state of the model, typically of shape
|
||||
(batch_size, num_patches * num_frames, hidden_size).
|
||||
|
||||
temporal_hidden_state (Optional[torch.FloatTensor]):
|
||||
The last hidden_state of the temporal encoder, typically of shape
|
||||
(batch_size * num_patches, num_frames, hidden_size).
|
||||
|
||||
spatial_hidden_state (Optional[torch.FloatTensor]):
|
||||
The last hidden_state of the spatial encoder, typically of shape
|
||||
(batch_size * num_frames, num_patches, hidden_size).
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor | None = None
|
||||
temporal_hidden_state: torch.FloatTensor | None = None
|
||||
spatial_hidden_state: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoPrismClipOutput(ModelOutput):
|
||||
"""
|
||||
Base class for VideoPrismClip model outputs.
|
||||
"""
|
||||
|
||||
logits_per_video: torch.FloatTensor | None = None
|
||||
logits_per_text: torch.FloatTensor | None = None
|
||||
video_embeds: torch.FloatTensor | None = None
|
||||
text_embeds: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoPrismVideoOutput(ModelOutput):
|
||||
"""
|
||||
Base class for VideoPrismVideo model outputs.
|
||||
"""
|
||||
|
||||
video_last_hidden_state: torch.FloatTensor | None = None
|
||||
auxiliary_output: torch.FloatTensor | None = None
|
||||
attention_pooling_output: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class VideoPrismTubeletEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct VideoPrism Tubelet embeddings.
|
||||
|
||||
This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
|
||||
shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
|
||||
|
||||
The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
|
||||
(width // tubelet_size[2]).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_frames = config.num_frames
|
||||
self.image_size = (
|
||||
config.image_size
|
||||
if isinstance(self.config.image_size, tuple)
|
||||
else (self.config.image_size, self.config.image_size)
|
||||
)
|
||||
self.patch_size = config.tubelet_size
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.projection = nn.Conv3d(
|
||||
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
|
||||
)
|
||||
self.pos_emb_shape = [self.image_size[0] // self.patch_size[1], self.image_size[1] // self.patch_size[2]]
|
||||
self.num_patches = self.pos_emb_shape[0] * self.pos_emb_shape[1]
|
||||
|
||||
def forward(self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]}). Set interpolate_pos_encoding=True to automatically resize the model position embeddings."
|
||||
)
|
||||
# permute to (batch_size, num_channels, num_frames, height, width)
|
||||
pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.projection(pixel_values_videos)
|
||||
# flatten the spatial part and permute to (B, T, num_patches, dim)
|
||||
hidden_states = hidden_states.flatten(3).permute(0, 2, 3, 1)
|
||||
# combine batch and time dimension
|
||||
batch_size, num_frames, num_patches, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(batch_size * num_frames, num_patches, hidden_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VideoPrismSpatialEmbeddings(nn.Module):
|
||||
"""
|
||||
VideoPrism Spatial Embeddings.
|
||||
|
||||
Creates embeddings from a video using VideoPrismSpatialTubeletEmbeddings and adds positional embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, self.patch_embeddings.num_patches, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.tubelet_size[1:]
|
||||
self.tubelet_size = config.tubelet_size
|
||||
|
||||
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1]
|
||||
num_positions = self.position_embeddings.shape[1]
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
num_row_patches = height // self.patch_size[0]
|
||||
num_col_patches = width // self.patch_size[1]
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(num_row_patches, num_col_patches),
|
||||
mode="bilinear",
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
|
||||
def forward(
|
||||
self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool | None = False
|
||||
) -> torch.Tensor:
|
||||
b, t, c, h, w = pixel_values_videos.shape
|
||||
assert h == w, "Input image height and width must be the same"
|
||||
embeddings = self.patch_embeddings(pixel_values_videos, interpolate_pos_encoding)
|
||||
|
||||
# add positional encoding to each token
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, h, w)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class VideoPrismTemporalEmbeddings(nn.Module):
|
||||
"""
|
||||
VideoPrism Temporal Embeddings.
|
||||
|
||||
Receives embeddings from spatial encoder, reshapes the hidden state to
|
||||
(batch_size * num_patches, num_frames, hidden_size) and adds positional embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, self.config.num_frames, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
target_emb_length = embeddings.shape[1]
|
||||
source_emb_length = self.position_embeddings.shape[1]
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and target_emb_length == source_emb_length:
|
||||
return self.position_embeddings
|
||||
|
||||
source_emb = self.position_embeddings
|
||||
dim = embeddings.shape[-1]
|
||||
source_emb = source_emb.unsqueeze(1)
|
||||
source_emb = nn.functional.interpolate(
|
||||
source_emb,
|
||||
size=(target_emb_length, dim),
|
||||
mode="bilinear",
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
return source_emb.squeeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_videos: torch.Tensor,
|
||||
input_shape: torch.Size,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
) -> torch.Tensor:
|
||||
if input_shape is not None:
|
||||
b, t, c, h, w = input_shape
|
||||
_, features, dim = pixel_values_videos.shape
|
||||
hidden_states = pixel_values_videos.view(b, t, features, dim)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
embeddings = hidden_states.reshape(b * features, t, dim)
|
||||
|
||||
# add positional encoding to each token
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
if softcap is not None:
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask.expand(*attn_weights.shape)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class VideoPrismSelfAttention(nn.Module):
|
||||
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scale = self.attention_head_size**-0.5
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
|
||||
query = self.query(hidden_states).view(*new_shape).transpose(1, 2)
|
||||
key = self.key(hidden_states).view(*new_shape).transpose(1, 2)
|
||||
value = self.value(hidden_states).view(*new_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
softcap=self.config.attn_logit_softcapping,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return (context_layer, attention_probs)
|
||||
|
||||
|
||||
class VideoPrismSelfOutput(nn.Module):
|
||||
"""
|
||||
The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the
|
||||
layernorm applied before each block.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VideoPrismConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VideoPrismAttention(nn.Module):
|
||||
def __init__(self, config: VideoPrismConfig):
|
||||
super().__init__()
|
||||
self.attention = VideoPrismSelfAttention(config)
|
||||
self.output = VideoPrismSelfOutput(config)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs
|
||||
) -> torch.Tensor:
|
||||
self_attn_output, _ = self.attention(hidden_states, attention_mask, **kwargs)
|
||||
output = self.output(self_attn_output, hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
class VideoPrismLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return F.layer_norm(hidden_states, self.normalized_shape, self.weight + 1, self.bias, self.eps)
|
||||
|
||||
|
||||
class VideoPrismIntermediate(nn.Module):
|
||||
def __init__(self, config: VideoPrismConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VideoPrismOutput(nn.Module):
|
||||
def __init__(self, config: VideoPrismConfig):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = hidden_states + input_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VideoPrismLayer(GradientCheckpointingLayer):
|
||||
"""This corresponds to the EncoderBlock class in the scenic/videoprism implementation."""
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.attention = VideoPrismAttention(config)
|
||||
self.intermediate = VideoPrismIntermediate(config)
|
||||
self.output = VideoPrismOutput(config)
|
||||
self.layernorm_before = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
|
||||
self.layernorm_after = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_norm = self.layernorm_before(hidden_states)
|
||||
attention_output = self.attention(hidden_states_norm, attention_mask, **kwargs)
|
||||
|
||||
# first residual connection
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
# in VideoPrism, layernorm is also applied after self-attention
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
|
||||
# second residual connection is done here
|
||||
layer_output = self.output(layer_output, hidden_states)
|
||||
|
||||
return layer_output
|
||||
|
||||
|
||||
class VideoPrismSpatialEncoder(nn.Module):
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states)
|
||||
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismTemporalEncoder(nn.Module):
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states)
|
||||
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismAuxiliaryEncoder(nn.Module):
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([VideoPrismLayer(self.config) for _ in range(config.num_auxiliary_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutput:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
|
||||
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismTextEncoder(nn.Module):
|
||||
def __init__(self, config: VideoPrismTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_text_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutput:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states, attention_mask, **kwargs)
|
||||
|
||||
return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
def variance_scaling_(tensor, mode="fan_in", distribution="normal"):
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
if mode == "fan_in":
|
||||
denom = fan_in
|
||||
elif mode == "fan_out":
|
||||
denom = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denom = (fan_in + fan_out) / 2
|
||||
|
||||
variance = 1.0 / denom
|
||||
|
||||
if distribution == "truncated_normal":
|
||||
init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
||||
elif distribution == "normal":
|
||||
init.normal_(tensor, std=math.sqrt(variance))
|
||||
elif distribution == "uniform":
|
||||
bound = math.sqrt(3 * variance)
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
else:
|
||||
raise ValueError(f"invalid distribution {distribution}")
|
||||
|
||||
|
||||
def lecun_normal_(tensor):
|
||||
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
||||
|
||||
|
||||
class VideoPrismPreTrainedModel(PreTrainedModel):
|
||||
config_class = VideoPrismConfig
|
||||
config: VideoPrismConfig
|
||||
base_model_prefix = "videoprism"
|
||||
main_input_name = "pixel_values_videos"
|
||||
input_modalities = ("video", "text")
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"VideoPrismSpatialEmbeddings",
|
||||
"VideoPrismTemporalEmbeddings",
|
||||
"VideoPrismSpatialEncoder",
|
||||
"VideoPrismTemporalEncoder",
|
||||
"VideoPrismAuxiliaryEncoder",
|
||||
"VideoPrismTextEncoder",
|
||||
"VideoPrismMultiheadAttentionPoolingHead",
|
||||
]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn = True
|
||||
_supports_attention_backend = True
|
||||
_supports_flex_attention = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
||||
lecun_normal_(module.weight)
|
||||
init.zeros_(module.bias)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
init.zeros_(module.bias)
|
||||
init.ones_(module.weight)
|
||||
|
||||
|
||||
class VideoPrismVisionModel(VideoPrismPreTrainedModel):
|
||||
config_class = VideoPrismVisionConfig
|
||||
config: VideoPrismVisionConfig
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.layernorm1 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
|
||||
self.layernorm2 = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
|
||||
self.spatial_embeddings = VideoPrismSpatialEmbeddings(self.config)
|
||||
self.temporal_embeddings = VideoPrismTemporalEmbeddings(self.config)
|
||||
self.spatial_encoder = VideoPrismSpatialEncoder(self.config)
|
||||
self.temporal_encoder = VideoPrismTemporalEncoder(self.config)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.spatial_embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_videos: torch.FloatTensor | None = None,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
**kwargs,
|
||||
) -> BaseModelOutputWithSpatialAndTemporalStates:
|
||||
r"""
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor`):
|
||||
Pixel values of the video frames of shape (batch_size, num_frames, num_channels, height, width).
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate positional encodings to match input size.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismVideoProcessor, VideoPrismVisionModel
|
||||
>>> import torch
|
||||
|
||||
>>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
|
||||
>>> model = VideoPrismVisionModel.from_pretrained("google/videoprism")
|
||||
|
||||
>>> video = "sample_video.mp4"
|
||||
>>> inputs = processor(videos=video)
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... features = outputs.last_hidden_state
|
||||
```
|
||||
"""
|
||||
if pixel_values_videos is None:
|
||||
raise ValueError("You have to specify pixel_values_videos")
|
||||
|
||||
input_shape = pixel_values_videos.shape
|
||||
spatial_embeds = self.spatial_embeddings(pixel_values_videos, interpolate_pos_encoding)
|
||||
spatial_encoder_outputs: BaseModelOutput = self.spatial_encoder(hidden_states=spatial_embeds, **kwargs)
|
||||
# shape of spatial_sequence_output is (B * num_frames, num_patches, dim)
|
||||
spatial_sequence_output = spatial_encoder_outputs.last_hidden_state
|
||||
features = self.layernorm1(spatial_sequence_output)
|
||||
|
||||
temporal_embeds = self.temporal_embeddings(features, input_shape, interpolate_pos_encoding)
|
||||
temporal_encoder_outputs: BaseModelOutput = self.temporal_encoder(hidden_states=temporal_embeds, **kwargs)
|
||||
# shape of temporal_sequence_output is (B * num_patches, num_frames, dim)
|
||||
temporal_sequence_output = temporal_encoder_outputs.last_hidden_state
|
||||
features = self.layernorm2(temporal_sequence_output)
|
||||
_, num_frames, dim = features.shape
|
||||
features = features.view(input_shape[0], -1, num_frames, dim).permute(0, 2, 1, 3).contiguous()
|
||||
_, num_frames, num_patches, dim = features.shape
|
||||
features = features.view(input_shape[0], num_frames * num_patches, -1)
|
||||
|
||||
return BaseModelOutputWithSpatialAndTemporalStates(
|
||||
last_hidden_state=features,
|
||||
temporal_hidden_state=temporal_sequence_output,
|
||||
spatial_hidden_state=spatial_sequence_output,
|
||||
)
|
||||
|
||||
|
||||
class VideoPrismMultiheadAttentionPoolingHead(nn.Module):
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_attention_heads = self.config.num_attention_heads
|
||||
self.attention_head_size = int(self.config.intermediate_size / self.config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = self.config.attention_probs_dropout_prob
|
||||
# PerDimScale
|
||||
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
|
||||
self.per_dim_scale = nn.Parameter(torch.zeros(self.dim))
|
||||
r_softplus_0 = 1.442695041
|
||||
scale = torch.tensor(r_softplus_0 / (self.dim**0.5))
|
||||
softplus = nn.functional.softplus(self.per_dim_scale)
|
||||
scale = scale * softplus
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self.pooling_attention_query = nn.Parameter(torch.zeros(1, 1, self.config.hidden_size))
|
||||
self.query = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
|
||||
self.key = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
|
||||
self.value = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
|
||||
self.projection = nn.Linear(self.config.intermediate_size, self.config.hidden_size, bias=self.config.qkv_bias)
|
||||
self.layernorm = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
|
||||
self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_size, seq_length, hidden_size = hidden_states.shape
|
||||
query = self.pooling_attention_query.expand(batch_size, -1, -1)
|
||||
query_layer = (
|
||||
self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
|
||||
)
|
||||
query_layer = query_layer * self.scale.expand(*query_layer.shape)
|
||||
|
||||
key_layer = (
|
||||
self.key(hidden_states)
|
||||
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_layer = (
|
||||
self.value(hidden_states)
|
||||
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
scaling=1.0,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
softcap=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
outputs = self.projection(context_layer)
|
||||
outputs = self.layernorm(outputs)
|
||||
return (outputs, attention_probs)
|
||||
|
||||
|
||||
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
|
||||
"""This function is intended to align with the l2norm implementation in the FLA library."""
|
||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
return x * inv_norm
|
||||
|
||||
|
||||
class VideoPrismTextModel(VideoPrismPreTrainedModel):
|
||||
config_class = VideoPrismTextConfig
|
||||
config: VideoPrismTextConfig
|
||||
|
||||
def __init__(self, config: VideoPrismTextConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.text_encoder = VideoPrismTextEncoder(self.config)
|
||||
self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.cls_emb = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.layernorm = VideoPrismLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.normalize = config.apply_l2_norm
|
||||
self.post_init()
|
||||
|
||||
def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / (dim - 2)))
|
||||
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
|
||||
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> BaseModelOutput:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.Tensor`):
|
||||
Input token IDs.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask to avoid performing attention on padding token indices.
|
||||
"""
|
||||
batch_size, seq_length = input_ids.shape
|
||||
hidden_states = self.token_embeddings(input_ids)
|
||||
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
||||
|
||||
cls_padding = torch.ones(batch_size, 1)
|
||||
input_ids = torch.cat((input_ids, cls_padding), dim=1)
|
||||
attention_mask = torch.cat((attention_mask, cls_padding), dim=1) if attention_mask is not None else None
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=torch.arange(hidden_states.shape[1] + 1, device=hidden_states.device),
|
||||
past_key_values=None,
|
||||
)
|
||||
|
||||
features = hidden_states + self.create_sinusoidal_positions(seq_length, self.config.hidden_size)
|
||||
cls_emb = self.cls_emb * (self.config.hidden_size**0.5)
|
||||
cls_emb = cls_emb.expand(features.shape[0], -1, -1)
|
||||
features = torch.cat((features, cls_emb), dim=1)
|
||||
text_encoder_output = self.text_encoder(features, attention_mask)
|
||||
features = text_encoder_output.last_hidden_state
|
||||
features = self.layernorm(features)
|
||||
text_embeddings = features[:, -1]
|
||||
|
||||
if self.normalize:
|
||||
text_embeddings = l2norm(text_embeddings, dim=-1)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=text_embeddings,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class VideoPrismVideoModel(VideoPrismPreTrainedModel):
|
||||
config_class = VideoPrismVisionConfig
|
||||
config: VideoPrismVisionConfig
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.backbone = VideoPrismVisionModel(self.config)
|
||||
self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(self.config)
|
||||
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
|
||||
self.normalize = self.config.apply_l2_norm
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.backbone.spatial_embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_videos: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
**kwargs,
|
||||
) -> VideoPrismVideoOutput:
|
||||
r"""
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor`):
|
||||
Pixel values of the video frames.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate positional encodings to match input size.
|
||||
"""
|
||||
backbone_outputs = self.backbone(
|
||||
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
|
||||
)
|
||||
video_features = backbone_outputs.last_hidden_state
|
||||
auxiliary_output = self.auxiliary_encoder(video_features)
|
||||
auxiliary_output_features = auxiliary_output.last_hidden_state
|
||||
contrastive_vision_pooler_output = self.contrastive_vision_pooler(auxiliary_output_features, **kwargs)
|
||||
video_embeddings = contrastive_vision_pooler_output[0]
|
||||
if self.normalize:
|
||||
video_embeddings = l2norm(video_embeddings, dim=-1)
|
||||
|
||||
return VideoPrismVideoOutput(
|
||||
video_last_hidden_state=video_embeddings,
|
||||
auxiliary_output=auxiliary_output,
|
||||
attention_pooling_output=contrastive_vision_pooler_output,
|
||||
)
|
||||
|
||||
|
||||
class VideoPrismClipModel(VideoPrismPreTrainedModel):
|
||||
config_class = VideoPrismConfig
|
||||
|
||||
def __init__(self, config: VideoPrismConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.vision_config = config.vision_config
|
||||
self.text_config = config.text_config
|
||||
self.video_model = VideoPrismVideoModel(self.vision_config)
|
||||
self.text_model = VideoPrismTextModel(self.text_config)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_videos: torch.FloatTensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
**kwargs,
|
||||
) -> VideoPrismClipOutput:
|
||||
r"""
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor`):
|
||||
Pixel values of the video frames.
|
||||
input_ids (`torch.Tensor`):
|
||||
Input token IDs for text.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for text inputs.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate positional encodings.
|
||||
temperature (`float`, *optional*):
|
||||
Temperature parameter for scaling similarity scores.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismProcessor, VideoPrismClipModel
|
||||
>>> import torch
|
||||
|
||||
>>> processor = VideoPrismProcessor.from_pretrained("google/videoprism")
|
||||
>>> model = VideoPrismClipModel.from_pretrained("google/videoprism")
|
||||
|
||||
>>> video = "sample_video.mp4"
|
||||
>>> texts = ["a dog", "a cat"]
|
||||
>>> inputs = processor(videos=video, texts=texts, return_tensors="pt", padding=True)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... logits_per_video = outputs.logits_per_video
|
||||
```
|
||||
"""
|
||||
video_model_outputs = self.video_model(
|
||||
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
|
||||
)
|
||||
text_model_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
video_embeddings = video_model_outputs.video_last_hidden_state
|
||||
text_embeddings = text_model_outputs.last_hidden_state
|
||||
emb_dim = video_embeddings[0].shape[-1]
|
||||
assert emb_dim == text_embeddings[0].shape[-1]
|
||||
|
||||
video_embeds = video_embeddings.reshape(-1, emb_dim)
|
||||
text_embeds = text_embeddings.reshape(-1, emb_dim)
|
||||
similarity_matrix = torch.matmul(video_embeds, text_embeds.T)
|
||||
|
||||
if temperature is not None:
|
||||
similarity_matrix /= temperature
|
||||
|
||||
logits_per_video = torch.exp(similarity_matrix)
|
||||
logits_per_text = logits_per_video.T
|
||||
logits_per_video = logits_per_video / torch.sum(logits_per_video, dim=0, keepdims=True)
|
||||
logits_per_text = logits_per_text / torch.sum(logits_per_text, dim=0, keepdims=True)
|
||||
|
||||
return VideoPrismClipOutput(
|
||||
logits_per_video=logits_per_video,
|
||||
logits_per_text=logits_per_text,
|
||||
video_embeds=video_embeds,
|
||||
text_embeds=text_embeds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class VideoPrismForVideoClassification(VideoPrismPreTrainedModel):
|
||||
config_class = VideoPrismVisionConfig
|
||||
config: VideoPrismVisionConfig
|
||||
|
||||
def __init__(self, config: VideoPrismVisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.encoder = VideoPrismVisionModel(self.config)
|
||||
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
|
||||
self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.spatial_embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_videos: torch.FloatTensor,
|
||||
labels: torch.LongTensor | None = None,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
**kwargs,
|
||||
) -> ImageClassifierOutput:
|
||||
r"""
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor`):
|
||||
Pixel values of the video frames.
|
||||
labels (`torch.LongTensor`, *optional*):
|
||||
Video classification labels.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate positional encodings.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import VideoPrismVideoProcessor, VideoPrismForVideoClassification
|
||||
>>> import torch
|
||||
|
||||
>>> processor = VideoPrismVideoProcessor("google/videoprism")
|
||||
>>> model = VideoPrismForVideoClassification.from_pretrained("google/videoprism", num_labels=1000)
|
||||
|
||||
>>> video = "sample_video.mp4"
|
||||
>>> inputs = processor(videos=video, return_tensors="pt")
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... logits = outputs.logits
|
||||
```
|
||||
"""
|
||||
encoder_outputs = self.encoder(
|
||||
pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
|
||||
)
|
||||
sequence_output = encoder_outputs.last_hidden_state
|
||||
pooled_output = self.contrastive_vision_pooler(sequence_output, **kwargs).pooled_output
|
||||
logits = self.classifier(pooled_output)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(labels, logits, self.config, **kwargs)
|
||||
|
||||
return ImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=encoder_outputs.last_hidden_state,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VideoPrismVisionModel",
|
||||
"VideoPrismPreTrainedModel",
|
||||
"VideoPrismVideoModel",
|
||||
"VideoPrismTextModel",
|
||||
"VideoPrismClipModel",
|
||||
"VideoPrismForVideoClassification",
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor
|
||||
from lerobot.policies.videovla.videoprism import VideoPrismVisionModel
|
||||
processor = VideoPrismVideoProcessor.from_pretrained(
|
||||
"MHRDYN7/videoprism-base-f16r288"
|
||||
)
|
||||
|
||||
model = VideoPrismVisionModel.from_pretrained(
|
||||
"MHRDYN7/videoprism-base-f16r288",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
|
||||
|
||||
vr = VideoDecoder(video_url)
|
||||
frame_idx = np.arange(0, 64)
|
||||
video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
|
||||
|
||||
video = processor(video, return_tensors="pt")
|
||||
video = {k: v.to(model.device, model.dtype) for k, v in video.items()}
|
||||
outputs = model(**video)
|
||||
encoder_outputs = outputs.last_hidden_state
|
||||
print(encoder_outputs.shape) #
|
||||
|
||||
import time
|
||||
import torch
|
||||
|
||||
# warmup
|
||||
for _ in range(10):
|
||||
_ = model(**video)
|
||||
|
||||
times = []
|
||||
for _ in range(50):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
|
||||
_ = model(**video)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.perf_counter()
|
||||
times.append(t1 - t0)
|
||||
|
||||
print(f"Mean: {1000*sum(times)/len(times):.2f} ms")
|
||||
print(f"Min : {1000*min(times):.2f} ms")
|
||||
print(f"Max : {1000*max(times):.2f} ms")
|
||||
@@ -1,44 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
|
||||
|
||||
class VideoPrismVideoProcessor(BaseVideoProcessor):
|
||||
r"""
|
||||
Constructs a VideoPrism video processor.
|
||||
|
||||
This processor inherits from [`LlavaOnevisionVideoProcessor`] and sets default parameters for VideoPrism models.
|
||||
Video frames are resized to 288x288 using bicubic resampling without normalization.
|
||||
|
||||
Args:
|
||||
size (`Dict[str, int]`, *optional*, defaults to `{"height": 288, "width": 288}`):
|
||||
The size to resize the video frames to.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
The resampling filter to use when resizing images.
|
||||
do_normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether to normalize the video frames.
|
||||
"""
|
||||
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
|
||||
size = {"height": 288, "width": 288}
|
||||
rescale_factor = 1 / 255
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = False
|
||||
do_convert_rgb = True
|
||||
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
|
||||
|
||||
|
||||
__all__ = ["VideoPrismVideoProcessor"]
|
||||
@@ -63,8 +63,8 @@ def update_policy(
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
sample_weighter=None,
|
||||
) -> tuple[MetricsTracker, dict | None]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
@@ -80,7 +80,7 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -90,27 +90,31 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
if sample_weighter is not None:
|
||||
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
if sample_weights is not None:
|
||||
# Use per-sample loss for weighted training
|
||||
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
# Weighted loss: each sample's contribution is scaled by its weight.
|
||||
# We divide by weight sum (not batch size) so that if some weights are zero,
|
||||
# the remaining samples contribute proportionally more, preserving gradient scale.
|
||||
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
|
||||
|
||||
# Log weighting statistics
|
||||
if output_dict is None:
|
||||
output_dict = {}
|
||||
for key, value in weight_stats.items():
|
||||
output_dict[f"sample_weight_{key}"] = value
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -288,27 +292,19 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
# Create sample weighter if configured (e.g., for RA-BC training)
|
||||
sample_weighter = None
|
||||
if cfg.sample_weighting is not None:
|
||||
from lerobot.utils.sample_weighting import make_sample_weighter
|
||||
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
if is_main_process:
|
||||
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||
sample_weighter = make_sample_weighter(
|
||||
cfg.sample_weighting,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=cfg.dataset.root,
|
||||
dataset_repo_id=cfg.dataset.repo_id,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
@@ -408,7 +404,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
sample_weighter=sample_weighter,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -425,16 +421,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
|
||||
@@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig):
|
||||
|
||||
# Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze
|
||||
# the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_pos: float = 37.0
|
||||
gripper_open_pos: float = 60.0
|
||||
|
||||
@@ -103,7 +103,7 @@ class OmxLeader(Teleoperator):
|
||||
self.calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=0,
|
||||
homing_offset=0 if motor != "gripper" else 100,
|
||||
range_min=0,
|
||||
range_max=4095,
|
||||
)
|
||||
@@ -123,12 +123,20 @@ class OmxLeader(Teleoperator):
|
||||
# point
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
if motor == "gripper":
|
||||
self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value)
|
||||
else:
|
||||
self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value)
|
||||
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
self.bus.write("Current_Limit", "gripper", 100)
|
||||
self.bus.write("Goal_Current", "gripper", 100)
|
||||
self.bus.write("Homing_Offset", "gripper", 100)
|
||||
# Set gripper's goal pos in current position mode so that we can use it as a trigger.
|
||||
self.bus.enable_torque("gripper")
|
||||
if self.is_calibrated:
|
||||
|
||||
239
src/lerobot/utils/sample_weighting.py
Normal file
239
src/lerobot/utils/sample_weighting.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Sample weighting abstraction for training.
|
||||
|
||||
This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
|
||||
that can be used during training without polluting the training script with
|
||||
policy-specific code.
|
||||
|
||||
Example usage:
|
||||
# In training config
|
||||
sample_weighting:
|
||||
type: rabc
|
||||
progress_path: hf://datasets/my-dataset/sarm_progress.parquet
|
||||
head_mode: sparse
|
||||
kappa: 0.01
|
||||
|
||||
# In training script
|
||||
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
|
||||
...
|
||||
weights, stats = sample_weighter.compute_batch_weights(batch)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class SampleWeighter(ABC):
|
||||
"""
|
||||
Implementations compute per-sample weights that can be used to weight
|
||||
the loss during training. This enables techniques like:
|
||||
- RA-BC (Reward-Aligned Behavior Cloning)
|
||||
- Importance sampling
|
||||
- Curriculum learning
|
||||
- Quality-based filtering
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute per-sample weights for a training batch.
|
||||
|
||||
Args:
|
||||
batch: Training batch dictionary containing at minimum an "index" key
|
||||
with global frame indices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get global statistics about the weighting strategy.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleWeightingConfig:
|
||||
"""
|
||||
Configuration for sample weighting during training.
|
||||
|
||||
This is a generic config that supports multiple weighting strategies.
|
||||
The `type` field determines which implementation to use, and `extra_params`
|
||||
contains additional type-specific parameters.
|
||||
|
||||
Attributes:
|
||||
type: Weighting strategy type ("rabc", "uniform", etc.)
|
||||
progress_path: Path to precomputed progress values (for RABC)
|
||||
head_mode: Which model head to use for progress ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (RABC-specific)
|
||||
epsilon: Small constant for numerical stability
|
||||
extra_params: Additional type-specific parameters passed to the weighter
|
||||
"""
|
||||
|
||||
type: str = "rabc"
|
||||
progress_path: str | None = None
|
||||
head_mode: str = "sparse"
|
||||
kappa: float = 0.01
|
||||
epsilon: float = 1e-6
|
||||
# Additional type-specific params can be added here or passed via extra_params
|
||||
extra_params: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
def make_sample_weighter(
|
||||
config: SampleWeightingConfig | None,
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
dataset_root: str | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> SampleWeighter | None:
|
||||
"""
|
||||
Factory function to create a SampleWeighter from config.
|
||||
|
||||
This keeps policy-specific initialization logic out of the training script.
|
||||
|
||||
Args:
|
||||
config: Sample weighting configuration, or None to disable weighting.
|
||||
policy: The policy being trained (used to extract chunk_size, etc.)
|
||||
device: Device to place weight tensors on.
|
||||
dataset_root: Local path to dataset root (for auto-detecting progress_path).
|
||||
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
if config.type == "rabc":
|
||||
return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
|
||||
|
||||
if config.type == "uniform":
|
||||
# No-op weighter that returns uniform weights
|
||||
return UniformWeighter(device=device)
|
||||
|
||||
raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'")
|
||||
|
||||
|
||||
def _make_rabc_weighter(
|
||||
config: SampleWeightingConfig,
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
dataset_root: str | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> SampleWeighter:
|
||||
"""Create RABC weighter with policy-specific initialization.
|
||||
|
||||
Args:
|
||||
config: Sample weighting configuration.
|
||||
policy: The policy being trained (used to extract chunk_size).
|
||||
device: Device to place weight tensors on.
|
||||
dataset_root: Local path to dataset root (for auto-detecting progress_path).
|
||||
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
|
||||
"""
|
||||
# Import here to avoid circular imports and keep RABC code in SARM module
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
# Extract chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError(
|
||||
"RABC sample weighting requires a policy with 'chunk_size' in its config. "
|
||||
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
|
||||
)
|
||||
|
||||
# Determine progress_path: use explicit config or auto-detect from dataset
|
||||
progress_path = config.progress_path
|
||||
if progress_path is None:
|
||||
if dataset_root:
|
||||
progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
|
||||
elif dataset_repo_id:
|
||||
progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
|
||||
else:
|
||||
raise ValueError(
|
||||
"RABC sample weighting requires 'progress_path' to be set, "
|
||||
"or dataset_root/dataset_repo_id for auto-detection. "
|
||||
"Generate progress values using: "
|
||||
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
|
||||
)
|
||||
|
||||
return RABCWeights(
|
||||
progress_path=progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=config.head_mode,
|
||||
kappa=config.kappa,
|
||||
epsilon=config.epsilon,
|
||||
device=device,
|
||||
**config.extra_params,
|
||||
)
|
||||
|
||||
|
||||
class UniformWeighter(SampleWeighter):
|
||||
"""
|
||||
No-op sample weighter that returns uniform weights.
|
||||
|
||||
Useful as a baseline or when you want to disable weighting without
|
||||
changing the training code structure.
|
||||
|
||||
Note:
|
||||
Batch size is determined by looking for tensor values in the batch
|
||||
dictionary. The method checks common keys like "action", "index",
|
||||
and "observation.state" first, then falls back to scanning all values.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
self.device = device
|
||||
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""Return uniform weights (all ones)."""
|
||||
batch_size = self._determine_batch_size(batch)
|
||||
|
||||
weights = torch.ones(batch_size, device=self.device)
|
||||
stats = {"mean_weight": 1.0, "type": "uniform"}
|
||||
return weights, stats
|
||||
|
||||
def _determine_batch_size(self, batch: dict) -> int:
|
||||
"""
|
||||
Determine batch size from the batch dictionary.
|
||||
|
||||
Checks common keys first, then scans all values for tensors.
|
||||
|
||||
Args:
|
||||
batch: Training batch dictionary.
|
||||
"""
|
||||
if not batch:
|
||||
raise ValueError("Cannot determine batch size from empty batch")
|
||||
|
||||
# Check common keys first
|
||||
for key in ["action", "index", "observation.state"]:
|
||||
if key in batch and isinstance(batch[key], torch.Tensor):
|
||||
return batch[key].shape[0]
|
||||
|
||||
# Scan all values for any tensor
|
||||
for value in batch.values():
|
||||
if isinstance(value, torch.Tensor) and value.ndim >= 1:
|
||||
return value.shape[0]
|
||||
|
||||
# Last resort: return 1 (this handles non-tensor batches)
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Return empty stats for uniform weighting."""
|
||||
return {"type": "uniform"}
|
||||
398
tests/utils/test_sample_weighting.py
Normal file
398
tests/utils/test_sample_weighting.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the sample weighting infrastructure."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.sample_weighting import (
|
||||
SampleWeighter,
|
||||
SampleWeightingConfig,
|
||||
UniformWeighter,
|
||||
make_sample_weighter,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_progress_parquet(tmp_path):
|
||||
"""Create a sample progress parquet file for testing."""
|
||||
import pandas as pd
|
||||
|
||||
# Create sample progress data for 2 episodes with 10 frames each
|
||||
data = {
|
||||
"index": list(range(20)),
|
||||
"episode_index": [0] * 10 + [1] * 10,
|
||||
"frame_index": list(range(10)) * 2,
|
||||
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
|
||||
}
|
||||
df = pd.DataFrame(data)
|
||||
parquet_path = tmp_path / "sarm_progress.parquet"
|
||||
df.to_parquet(parquet_path)
|
||||
return parquet_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SampleWeightingConfig Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_config_default_values():
|
||||
"""Test default configuration values."""
|
||||
config = SampleWeightingConfig()
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path is None
|
||||
assert config.head_mode == "sparse"
|
||||
assert config.kappa == 0.01
|
||||
assert config.epsilon == 1e-6
|
||||
assert config.extra_params == {}
|
||||
|
||||
|
||||
def test_config_custom_values():
|
||||
"""Test configuration with custom values."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
head_mode="dense",
|
||||
kappa=0.05,
|
||||
epsilon=1e-8,
|
||||
extra_params={"fallback_weight": 0.5},
|
||||
)
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path == "/path/to/progress.parquet"
|
||||
assert config.head_mode == "dense"
|
||||
assert config.kappa == 0.05
|
||||
assert config.epsilon == 1e-8
|
||||
assert config.extra_params == {"fallback_weight": 0.5}
|
||||
|
||||
|
||||
def test_config_uniform_type():
|
||||
"""Test configuration for uniform weighting."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
assert config.type == "uniform"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UniformWeighter Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_uniform_weighter_inherits_from_sample_weighter():
|
||||
"""Test that UniformWeighter is a SampleWeighter."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_action_key():
|
||||
"""Test weight computation with 'action' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"action": torch.randn(8, 10)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (8,)
|
||||
assert torch.allclose(weights, torch.ones(8))
|
||||
assert stats["mean_weight"] == 1.0
|
||||
assert stats["type"] == "uniform"
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_index_key():
|
||||
"""Test weight computation with 'index' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"index": torch.arange(16)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (16,)
|
||||
assert torch.allclose(weights, torch.ones(16))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
|
||||
"""Test weight computation with no tensor keys (fallback to size 1)."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"other_key": "some_value"}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (1,)
|
||||
assert torch.allclose(weights, torch.ones(1))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
|
||||
"""Test that empty batch raises ValueError."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {}
|
||||
|
||||
with pytest.raises(ValueError, match="empty batch"):
|
||||
weighter.compute_batch_weights(batch)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
|
||||
"""Test that batch size is determined by scanning all tensor values."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
# Batch with non-standard key containing a tensor
|
||||
batch = {"custom_tensor": torch.randn(7, 3)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (7,)
|
||||
assert torch.allclose(weights, torch.ones(7))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_on_cuda():
|
||||
"""Test that weights are placed on the correct device."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
weighter = UniformWeighter(device=torch.device("cuda"))
|
||||
batch = {"action": torch.randn(4, 10)}
|
||||
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.device.type == "cuda"
|
||||
|
||||
|
||||
def test_uniform_weighter_get_stats():
|
||||
"""Test get_stats returns expected structure."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats == {"type": "uniform"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# make_sample_weighter Factory Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_factory_returns_none_for_none_config():
|
||||
"""Test that None config returns None weighter."""
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
result = make_sample_weighter(None, policy, device)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_factory_creates_uniform_weighter():
|
||||
"""Test creation of UniformWeighter."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, UniformWeighter)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_factory_raises_for_unknown_type():
|
||||
"""Test that unknown type raises ValueError."""
|
||||
config = SampleWeightingConfig(type="unknown_type")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown sample weighting type"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_chunk_size():
|
||||
"""Test that RABC weighter requires chunk_size in policy config."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = None # No chunk_size
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="chunk_size"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_progress_path_or_dataset_info():
|
||||
"""Test that RABC weighter requires progress_path or dataset info for auto-detection."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # No progress path
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Should fail when no progress_path AND no dataset info
|
||||
with pytest.raises(ValueError, match="progress_path"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
|
||||
"""Test that RABC weighter auto-detects progress_path from dataset_root."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
# The parquet file is at sample_progress_parquet, get its parent directory
|
||||
dataset_root = sample_progress_parquet.parent
|
||||
weighter = make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=str(dataset_root),
|
||||
)
|
||||
|
||||
assert weighter is not None
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_repo_id():
|
||||
"""Test that RABC weighter constructs HF path from repo_id."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# This will construct the path but fail when trying to load (file doesn't exist)
|
||||
# We just verify it doesn't raise the "progress_path required" error
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_repo_id="test-user/test-dataset",
|
||||
)
|
||||
# Should NOT be the "progress_path required" error - it should try to load the file
|
||||
assert (
|
||||
"progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests with RABCWeights
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
|
||||
"""Test that RABCWeights inherits from SampleWeighter."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_compute_batch_weights(sample_progress_parquet):
|
||||
"""Test RABCWeights.compute_batch_weights returns correct structure."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert isinstance(weights, torch.Tensor)
|
||||
assert weights.shape == (4,)
|
||||
assert isinstance(stats, dict)
|
||||
assert "mean_weight" in stats
|
||||
|
||||
|
||||
def test_rabc_get_stats(sample_progress_parquet):
|
||||
"""Test RABCWeights.get_stats returns expected structure."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats["type"] == "rabc"
|
||||
assert "num_frames" in stats
|
||||
assert "chunk_size" in stats
|
||||
assert stats["chunk_size"] == 5
|
||||
assert "head_mode" in stats
|
||||
assert stats["head_mode"] == "sparse"
|
||||
assert "delta_mean" in stats
|
||||
assert "delta_std" in stats
|
||||
|
||||
|
||||
def test_factory_creates_rabc_weighter(sample_progress_parquet):
|
||||
"""Test factory creates RABCWeights with valid config."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=str(sample_progress_parquet),
|
||||
head_mode="sparse",
|
||||
kappa=0.01,
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_weights_normalization(sample_progress_parquet):
|
||||
"""Test that RABCWeights normalizes weights to sum to batch_size."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
# Weights should be normalized to sum approximately to batch_size
|
||||
batch_size = 4
|
||||
assert abs(weights.sum().item() - batch_size) < 0.1
|
||||
Reference in New Issue
Block a user