|
|
|
|
@@ -32,6 +32,7 @@ from lerobot.utils.import_utils import _transformers_available
|
|
|
|
|
|
|
|
|
|
# Conditional import for type checking and lazy loading
|
|
|
|
|
if TYPE_CHECKING or _transformers_available:
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
from transformers.models.auto import CONFIG_MAPPING
|
|
|
|
|
from transformers.models.gemma import modeling_gemma
|
|
|
|
|
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
|
|
|
|
@@ -41,6 +42,7 @@ else:
|
|
|
|
|
modeling_gemma = None
|
|
|
|
|
GemmaForCausalLM = None
|
|
|
|
|
PaliGemmaForConditionalGeneration = None
|
|
|
|
|
AutoTokenizer = None
|
|
|
|
|
|
|
|
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
|
|
|
|
@@ -48,6 +50,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy, T
|
|
|
|
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
|
|
|
|
from lerobot.utils.constants import (
|
|
|
|
|
ACTION,
|
|
|
|
|
ACTION_TOKEN_MASK,
|
|
|
|
|
ACTION_TOKENS,
|
|
|
|
|
OBS_LANGUAGE_ATTENTION_MASK,
|
|
|
|
|
OBS_LANGUAGE_TOKENS,
|
|
|
|
|
OPENPI_ATTENTION_MASK_VALUE,
|
|
|
|
|
@@ -55,81 +59,7 @@ from lerobot.utils.constants import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActionSelectKwargs(TypedDict, total=False):
|
|
|
|
|
inference_delay: int | None
|
|
|
|
|
prev_chunk_left_over: Tensor | None
|
|
|
|
|
execution_horizon: int | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_safe_dtype(target_dtype, device_type):
|
|
|
|
|
"""Get a safe dtype for the given device type."""
|
|
|
|
|
if device_type == "mps" and target_dtype == torch.float64:
|
|
|
|
|
return torch.float32
|
|
|
|
|
if device_type == "cpu":
|
|
|
|
|
# CPU doesn't support bfloat16, use float32 instead
|
|
|
|
|
if target_dtype == torch.bfloat16:
|
|
|
|
|
return torch.float32
|
|
|
|
|
if target_dtype == torch.float64:
|
|
|
|
|
return torch.float64
|
|
|
|
|
return target_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
|
|
|
|
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
|
|
|
|
) -> Tensor:
|
|
|
|
|
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
|
|
|
|
if dimension % 2 != 0:
|
|
|
|
|
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
|
|
|
|
|
|
|
|
|
if time.ndim != 1:
|
|
|
|
|
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
|
|
|
|
|
|
|
|
|
dtype = get_safe_dtype(torch.float64, device.type)
|
|
|
|
|
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
|
|
|
|
period = min_period * (max_period / min_period) ** fraction
|
|
|
|
|
|
|
|
|
|
# Compute the outer product
|
|
|
|
|
scaling_factor = 1.0 / period * 2 * math.pi
|
|
|
|
|
sin_input = scaling_factor[None, :] * time[:, None]
|
|
|
|
|
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
|
|
|
|
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
|
|
|
|
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
|
|
|
|
dist = torch.distributions.Beta(alpha_t, beta_t)
|
|
|
|
|
return dist.sample((bsize,))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
|
|
|
|
"""Copied from big_vision.
|
|
|
|
|
|
|
|
|
|
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
|
|
|
|
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
|
|
|
|
setup several types of attention, for example:
|
|
|
|
|
|
|
|
|
|
[[1 1 1 1 1 1]]: pure causal attention.
|
|
|
|
|
|
|
|
|
|
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
|
|
|
|
themselves and the last 3 tokens have a causal attention. The first
|
|
|
|
|
entry could also be a 1 without changing behaviour.
|
|
|
|
|
|
|
|
|
|
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
|
|
|
|
block can attend all previous blocks and all tokens on the same block.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input_mask: bool[B, N] true if its part of the input, false if padding.
|
|
|
|
|
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
|
|
|
|
it and 0 where it shares the same attention mask as the previous token.
|
|
|
|
|
"""
|
|
|
|
|
if att_masks.ndim != 2:
|
|
|
|
|
raise ValueError(att_masks.ndim)
|
|
|
|
|
if pad_masks.ndim != 2:
|
|
|
|
|
raise ValueError(pad_masks.ndim)
|
|
|
|
|
|
|
|
|
|
cumsum = torch.cumsum(att_masks, dim=1)
|
|
|
|
|
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
|
|
|
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
|
|
|
|
return att_2d_masks & pad_2d_masks
|
|
|
|
|
temperature: float | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_vector(vector, new_dim):
|
|
|
|
|
@@ -514,10 +444,16 @@ class PaliGemmaWithExpertModel(
|
|
|
|
|
class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
"""Core PI0Fast PyTorch model."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: PI0FastConfig, rtc_processor: RTCProcessor | None = None):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
config: PI0FastConfig,
|
|
|
|
|
rtc_processor: RTCProcessor | None = None,
|
|
|
|
|
paligemma_tokenizer: AutoTokenizer | None = None,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.config = config
|
|
|
|
|
self.rtc_processor = rtc_processor
|
|
|
|
|
self._paligemma_tokenizer = paligemma_tokenizer
|
|
|
|
|
|
|
|
|
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
|
|
|
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
|
|
|
|
@@ -528,19 +464,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
use_adarms=[False, True],
|
|
|
|
|
precision=config.dtype,
|
|
|
|
|
)
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
|
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
"google/paligemma-3b-pt-224",
|
|
|
|
|
trust_remote_code=True,
|
|
|
|
|
)
|
|
|
|
|
# # Apply dtype conversion to FAST layers to match model precision
|
|
|
|
|
# if config.dtype == "bfloat16":
|
|
|
|
|
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
|
|
|
|
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
|
|
|
|
# elif config.dtype == "float32":
|
|
|
|
|
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
|
|
|
|
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
# Initialize gradient checkpointing flag
|
|
|
|
|
self.gradient_checkpointing_enabled = False
|
|
|
|
|
@@ -548,8 +471,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
# Compile model if requested
|
|
|
|
|
if config.compile_model:
|
|
|
|
|
torch.set_float32_matmul_precision("high")
|
|
|
|
|
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
|
|
|
|
# Also compile the main forward pass used during training
|
|
|
|
|
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
|
|
|
|
|
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
|
|
|
|
|
|
|
|
|
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
|
|
|
|
@@ -578,9 +500,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
|
|
|
|
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
|
|
|
|
|
|
|
|
|
|
def _rtc_enabled(self):
|
|
|
|
|
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
|
|
|
|
|
|
|
|
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
|
|
|
|
"""Helper method to apply gradient checkpointing if enabled."""
|
|
|
|
|
if self.gradient_checkpointing_enabled and self.training:
|
|
|
|
|
@@ -597,352 +516,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
result = result.to(dtype=dtype)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
|
|
|
|
"""Create custom 2D attention mask for the new attention pattern.
|
|
|
|
|
|
|
|
|
|
Attention rules:
|
|
|
|
|
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
|
|
|
|
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
|
|
|
|
- FAST: attend to images + language + subtask, causal among themselves
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
att_mask_segments: List of (type, length) tuples
|
|
|
|
|
pad_masks: Padding masks [B, total_seq_len]
|
|
|
|
|
bsize: Batch size
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
|
|
|
|
"""
|
|
|
|
|
total_len = sum(length for _, length in att_mask_segments)
|
|
|
|
|
device = pad_masks.device
|
|
|
|
|
|
|
|
|
|
# Initialize attention mask as False (cannot attend)
|
|
|
|
|
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
|
|
|
|
|
|
|
|
|
# Track positions for each segment
|
|
|
|
|
positions = []
|
|
|
|
|
current_pos = 0
|
|
|
|
|
for seg_type, seg_len in att_mask_segments:
|
|
|
|
|
positions.append((seg_type, current_pos, current_pos + seg_len))
|
|
|
|
|
current_pos += seg_len
|
|
|
|
|
|
|
|
|
|
# Apply attention rules
|
|
|
|
|
for _i, (query_type, query_start, query_end) in enumerate(positions):
|
|
|
|
|
for _j, (key_type, key_start, key_end) in enumerate(positions):
|
|
|
|
|
# Images and Language can attend to each other bidirectionally
|
|
|
|
|
if (
|
|
|
|
|
query_type in ["image", "language"]
|
|
|
|
|
and key_type in ["image", "language"]
|
|
|
|
|
or query_type == "subtask"
|
|
|
|
|
and key_type in ["image", "language"]
|
|
|
|
|
):
|
|
|
|
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
|
|
|
|
|
|
|
|
|
# Subtask tokens attend causally to themselves
|
|
|
|
|
elif query_type == "subtask" and key_type == "subtask":
|
|
|
|
|
# Create causal mask for subtask tokens
|
|
|
|
|
subtask_len = query_end - query_start
|
|
|
|
|
causal_mask = torch.tril(
|
|
|
|
|
torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device)
|
|
|
|
|
)
|
|
|
|
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
|
|
|
|
|
|
|
|
|
# FAST tokens attend to images + language + subtask
|
|
|
|
|
elif query_type == "fast" and key_type in ["image", "language", "subtask"]:
|
|
|
|
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
|
|
|
|
|
|
|
|
|
# FAST tokens attend causally to themselves
|
|
|
|
|
elif query_type == "fast" and key_type == "fast":
|
|
|
|
|
fast_len = query_end - query_start
|
|
|
|
|
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
|
|
|
|
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
|
|
|
|
|
|
|
|
|
# Apply padding masks
|
|
|
|
|
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
|
|
|
|
att_2d_masks = att_2d_masks & pad_2d_masks
|
|
|
|
|
|
|
|
|
|
return att_2d_masks
|
|
|
|
|
|
|
|
|
|
def visualize_attention_mask(
|
|
|
|
|
self, att_mask_segments, att_2d_masks, save_path, batch_idx=0, dpi=150, max_display_tokens=None
|
|
|
|
|
):
|
|
|
|
|
"""Visualize the attention mask with labeled segments.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
att_mask_segments: List of (type, length) tuples defining the segments
|
|
|
|
|
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
|
|
|
|
|
save_path: Path where to save the visualization image
|
|
|
|
|
batch_idx: Which batch item to visualize (default: 0)
|
|
|
|
|
dpi: DPI for the saved image (default: 150)
|
|
|
|
|
max_display_tokens: Maximum number of tokens to display (for very long sequences)
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
import matplotlib.patches as mpatches
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
|
|
|
except ImportError:
|
|
|
|
|
logging.warning("matplotlib not available, skipping attention mask visualization")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Extract the mask for the specified batch
|
|
|
|
|
mask = att_2d_masks[batch_idx].cpu().float().numpy()
|
|
|
|
|
|
|
|
|
|
# If sequence is too long, downsample for visualization
|
|
|
|
|
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
|
|
|
|
|
# Simple downsampling by taking every Nth token
|
|
|
|
|
step = mask.shape[0] // max_display_tokens
|
|
|
|
|
mask = mask[::step, ::step]
|
|
|
|
|
# Adjust segments accordingly
|
|
|
|
|
att_mask_segments = [
|
|
|
|
|
(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Calculate positions for each segment
|
|
|
|
|
positions = []
|
|
|
|
|
current_pos = 0
|
|
|
|
|
for seg_type, seg_len in att_mask_segments:
|
|
|
|
|
positions.append((seg_type, current_pos, current_pos + seg_len))
|
|
|
|
|
current_pos += seg_len
|
|
|
|
|
|
|
|
|
|
# Create figure
|
|
|
|
|
fig, ax = plt.subplots(figsize=(12, 10))
|
|
|
|
|
|
|
|
|
|
# Create custom colormap: white for False (no attention), blue for True (attention)
|
|
|
|
|
colors = ["white", "#2E86AB"]
|
|
|
|
|
n_bins = 2
|
|
|
|
|
cmap = LinearSegmentedColormap.from_list("attention", colors, N=n_bins)
|
|
|
|
|
|
|
|
|
|
# Display the mask
|
|
|
|
|
im = ax.imshow(mask, cmap=cmap, aspect="auto", interpolation="nearest", vmin=0, vmax=1)
|
|
|
|
|
|
|
|
|
|
# Add colorbar
|
|
|
|
|
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
|
|
|
|
cbar.set_label("Attention Enabled", rotation=270, labelpad=20)
|
|
|
|
|
cbar.set_ticks([0.25, 0.75])
|
|
|
|
|
cbar.set_ticklabels(["No", "Yes"])
|
|
|
|
|
|
|
|
|
|
# Define colors for each segment type
|
|
|
|
|
segment_colors = {"image": "#A23B72", "language": "#F18F01", "subtask": "#C73E1D", "fast": "#6A994E"}
|
|
|
|
|
|
|
|
|
|
# Draw segment boundaries and labels
|
|
|
|
|
for seg_type, start, end in positions:
|
|
|
|
|
color = segment_colors.get(seg_type, "#666666")
|
|
|
|
|
|
|
|
|
|
# Draw vertical lines for columns (keys)
|
|
|
|
|
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
|
|
|
|
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
|
|
|
|
|
|
|
|
|
# Draw horizontal lines for rows (queries)
|
|
|
|
|
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
|
|
|
|
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
|
|
|
|
|
|
|
|
|
# Add labels at the top
|
|
|
|
|
mid_pos = (start + end) / 2
|
|
|
|
|
ax.text(
|
|
|
|
|
mid_pos,
|
|
|
|
|
-mask.shape[0] * 0.02,
|
|
|
|
|
f"{seg_type.upper()}\n({end - start})",
|
|
|
|
|
ha="center",
|
|
|
|
|
va="top",
|
|
|
|
|
fontsize=10,
|
|
|
|
|
fontweight="bold",
|
|
|
|
|
color=color,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Add labels on the left
|
|
|
|
|
ax.text(
|
|
|
|
|
-mask.shape[1] * 0.02,
|
|
|
|
|
mid_pos,
|
|
|
|
|
f"{seg_type.upper()}\n({end - start})",
|
|
|
|
|
ha="right",
|
|
|
|
|
va="center",
|
|
|
|
|
fontsize=10,
|
|
|
|
|
fontweight="bold",
|
|
|
|
|
color=color,
|
|
|
|
|
rotation=0,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Set axis labels
|
|
|
|
|
ax.set_xlabel("Key Position (tokens being attended to)", fontsize=12, fontweight="bold")
|
|
|
|
|
ax.set_ylabel("Query Position (tokens attending)", fontsize=12, fontweight="bold")
|
|
|
|
|
ax.set_title(
|
|
|
|
|
"Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)",
|
|
|
|
|
fontsize=14,
|
|
|
|
|
fontweight="bold",
|
|
|
|
|
pad=20,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create legend for segment types
|
|
|
|
|
legend_patches = []
|
|
|
|
|
attention_rules = {
|
|
|
|
|
"image": "Bidirectional with lang",
|
|
|
|
|
"language": "Bidirectional with images",
|
|
|
|
|
"subtask": "Attends to img+lang, causal self",
|
|
|
|
|
"fast": "Attends to all, causal self",
|
|
|
|
|
}
|
|
|
|
|
for seg_type, color in segment_colors.items():
|
|
|
|
|
if any(seg[0] == seg_type for seg in att_mask_segments):
|
|
|
|
|
rule = attention_rules.get(seg_type, "")
|
|
|
|
|
legend_patches.append(mpatches.Patch(color=color, label=f"{seg_type.upper()}: {rule}"))
|
|
|
|
|
|
|
|
|
|
ax.legend(
|
|
|
|
|
handles=legend_patches, loc="upper right", bbox_to_anchor=(1.15, 1.0), framealpha=0.9, fontsize=9
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Adjust layout and save
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
# Ensure the directory exists
|
|
|
|
|
save_path = Path(save_path)
|
|
|
|
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
|
|
|
|
|
plt.close()
|
|
|
|
|
|
|
|
|
|
logging.info(f"Attention mask visualization saved to: {save_path}")
|
|
|
|
|
|
|
|
|
|
def sample_noise(self, shape, device):
|
|
|
|
|
return torch.normal(
|
|
|
|
|
mean=0.0,
|
|
|
|
|
std=1.0,
|
|
|
|
|
size=shape,
|
|
|
|
|
dtype=torch.float32,
|
|
|
|
|
device=device,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def sample_time(self, bsize, device):
|
|
|
|
|
time_beta = sample_beta(
|
|
|
|
|
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
|
|
|
|
)
|
|
|
|
|
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
|
|
|
|
return time.to(dtype=torch.float32, device=device)
|
|
|
|
|
|
|
|
|
|
def embed_prefix(
|
|
|
|
|
self,
|
|
|
|
|
images,
|
|
|
|
|
img_masks,
|
|
|
|
|
tokens,
|
|
|
|
|
subtask_tokens,
|
|
|
|
|
masks,
|
|
|
|
|
subtask_masks,
|
|
|
|
|
fast_action_tokens=None,
|
|
|
|
|
fast_action_masks=None,
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
|
|
|
|
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
images: List of image tensors
|
|
|
|
|
img_masks: List of image masks
|
|
|
|
|
tokens: Language instruction tokens
|
|
|
|
|
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
|
|
|
|
masks: Attention masks for tokens
|
|
|
|
|
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
|
|
|
|
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
|
|
|
|
pad_masks: Padding masks
|
|
|
|
|
att_masks: Custom 2D attention mask implementing the required pattern
|
|
|
|
|
total_T_images: Total number of image tokens
|
|
|
|
|
num_subtask_embs: Number of subtask token embeddings
|
|
|
|
|
num_fast_embs: Number of FAST action token embeddings
|
|
|
|
|
"""
|
|
|
|
|
embs = []
|
|
|
|
|
pad_masks = []
|
|
|
|
|
att_mask_segments = [] # Store info about each segment for custom mask creation
|
|
|
|
|
total_t_images = 0
|
|
|
|
|
num_subtask_embs = 0
|
|
|
|
|
num_fast_embs = 0
|
|
|
|
|
|
|
|
|
|
# Process images
|
|
|
|
|
for img, img_mask in zip(images, img_masks, strict=True):
|
|
|
|
|
|
|
|
|
|
def image_embed_func(img):
|
|
|
|
|
return self.paligemma_with_expert.embed_image(img)
|
|
|
|
|
|
|
|
|
|
img_emb = self._apply_checkpoint(image_embed_func, img)
|
|
|
|
|
bsize, num_img_embs = img_emb.shape[:2]
|
|
|
|
|
|
|
|
|
|
embs.append(img_emb)
|
|
|
|
|
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
|
|
|
|
att_mask_segments.append(("image", num_img_embs))
|
|
|
|
|
total_t_images += num_img_embs
|
|
|
|
|
|
|
|
|
|
# Process language instruction tokens
|
|
|
|
|
def lang_embed_func(tokens):
|
|
|
|
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
|
|
|
|
lang_emb_dim = lang_emb.shape[-1]
|
|
|
|
|
return lang_emb * math.sqrt(lang_emb_dim)
|
|
|
|
|
|
|
|
|
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
|
|
|
|
embs.append(lang_emb)
|
|
|
|
|
pad_masks.append(masks)
|
|
|
|
|
|
|
|
|
|
num_lang_embs = lang_emb.shape[1]
|
|
|
|
|
att_mask_segments.append(("language", num_lang_embs))
|
|
|
|
|
|
|
|
|
|
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
|
|
|
|
if subtask_tokens is not None:
|
|
|
|
|
|
|
|
|
|
def subtask_embed_func(subtask_tokens):
|
|
|
|
|
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
|
|
|
|
subtask_emb_dim = subtask_emb.shape[-1]
|
|
|
|
|
return subtask_emb * math.sqrt(subtask_emb_dim)
|
|
|
|
|
|
|
|
|
|
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
|
|
|
|
embs.append(subtask_emb)
|
|
|
|
|
|
|
|
|
|
# Create subtask pad masks (non-zero tokens are valid)
|
|
|
|
|
pad_masks.append(subtask_masks)
|
|
|
|
|
|
|
|
|
|
num_subtask_embs = subtask_emb.shape[1]
|
|
|
|
|
att_mask_segments.append(("subtask", num_subtask_embs))
|
|
|
|
|
# Process FAST action tokens if provided (these are discrete token IDs)
|
|
|
|
|
if fast_action_tokens is not None:
|
|
|
|
|
|
|
|
|
|
def fast_action_embed_func(fast_action_tokens):
|
|
|
|
|
fast_emb = self.fast_action_embedding(fast_action_tokens)
|
|
|
|
|
fast_emb_dim = fast_emb.shape[-1]
|
|
|
|
|
return fast_emb * math.sqrt(fast_emb_dim)
|
|
|
|
|
|
|
|
|
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
|
|
|
|
embs.append(fast_action_emb)
|
|
|
|
|
|
|
|
|
|
# Use provided mask or create default (all valid)
|
|
|
|
|
if fast_action_masks is not None:
|
|
|
|
|
fast_pad_mask = fast_action_masks
|
|
|
|
|
else:
|
|
|
|
|
bsize = fast_action_tokens.shape[0]
|
|
|
|
|
num_fast_embs = fast_action_tokens.shape[1]
|
|
|
|
|
fast_pad_mask = torch.ones(
|
|
|
|
|
bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
num_fast_embs = fast_action_tokens.shape[1]
|
|
|
|
|
pad_masks.append(fast_pad_mask)
|
|
|
|
|
att_mask_segments.append(("fast", num_fast_embs))
|
|
|
|
|
|
|
|
|
|
embs = torch.cat(embs, dim=1)
|
|
|
|
|
pad_masks = torch.cat(pad_masks, dim=1)
|
|
|
|
|
|
|
|
|
|
# Create custom 2D attention mask
|
|
|
|
|
# Attention rules:
|
|
|
|
|
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
|
|
|
|
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
|
|
|
|
# - FAST: attend to images + language + subtask, causal among themselves
|
|
|
|
|
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
|
|
|
|
|
|
|
|
|
# # Optionally visualize the attention mask
|
|
|
|
|
# self.visualize_attention_mask(
|
|
|
|
|
# att_mask_segments=att_mask_segments,
|
|
|
|
|
# att_2d_masks=att_masks,
|
|
|
|
|
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
|
|
|
|
# batch_idx=0,
|
|
|
|
|
# max_display_tokens=512 # Limit display for very long sequences
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
return embs, pad_masks, att_masks, total_t_images, num_subtask_embs, num_fast_embs
|
|
|
|
|
|
|
|
|
|
def embed_prefix_fast(
|
|
|
|
|
self,
|
|
|
|
|
images,
|
|
|
|
|
@@ -952,9 +525,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
fast_action_tokens=None,
|
|
|
|
|
fast_action_masks=None,
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
|
|
|
|
|
"""Embed images, language tokens, and FAST action tokens for FAST-only mode.
|
|
|
|
|
"""Embed images, language tokens, and FAST action tokens.
|
|
|
|
|
|
|
|
|
|
This is a simplified version of embed_prefix without subtask tokens.
|
|
|
|
|
Attention pattern:
|
|
|
|
|
- Images + Language: bidirectional among themselves
|
|
|
|
|
- FAST: attend to images + language, causal among themselves
|
|
|
|
|
@@ -1018,23 +590,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
|
|
|
|
embs.append(fast_action_emb)
|
|
|
|
|
|
|
|
|
|
if fast_action_masks is not None:
|
|
|
|
|
fast_pad_mask = fast_action_masks
|
|
|
|
|
else:
|
|
|
|
|
bsize = fast_action_tokens.shape[0]
|
|
|
|
|
num_fast_embs = fast_action_tokens.shape[1]
|
|
|
|
|
fast_pad_mask = torch.ones(
|
|
|
|
|
bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
num_fast_embs = fast_action_tokens.shape[1]
|
|
|
|
|
pad_masks.append(fast_pad_mask)
|
|
|
|
|
pad_masks.append(fast_action_masks)
|
|
|
|
|
att_mask_segments.append(("fast", num_fast_embs))
|
|
|
|
|
|
|
|
|
|
embs = torch.cat(embs, dim=1)
|
|
|
|
|
pad_masks = torch.cat(pad_masks, dim=1)
|
|
|
|
|
|
|
|
|
|
# Create custom 2D attention mask for FAST-only mode:
|
|
|
|
|
# Create custom 2D attention mask:
|
|
|
|
|
# - Images + Language: bidirectional among themselves
|
|
|
|
|
# - FAST: attend to images + language, causal among themselves
|
|
|
|
|
att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize)
|
|
|
|
|
@@ -1042,7 +605,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
return embs, pad_masks, att_masks, total_t_images, num_fast_embs
|
|
|
|
|
|
|
|
|
|
def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize):
|
|
|
|
|
"""Create custom 2D attention mask for FAST-only mode.
|
|
|
|
|
"""Create custom 2D attention mask.
|
|
|
|
|
|
|
|
|
|
Attention rules:
|
|
|
|
|
- Images + Language: bidirectional among themselves
|
|
|
|
|
@@ -1091,7 +654,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
fast_action_tokens,
|
|
|
|
|
fast_action_masks,
|
|
|
|
|
) -> dict:
|
|
|
|
|
"""Forward pass for FAST-only mode (no flow matching, no subtask).
|
|
|
|
|
"""Forward pass for PI0Fast.
|
|
|
|
|
|
|
|
|
|
This implements the Pi0FAST training objective: predict next action token
|
|
|
|
|
using cross-entropy loss.
|
|
|
|
|
@@ -1151,11 +714,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
# only compute logits for the positions that predict FAST tokens
|
|
|
|
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
|
|
|
|
|
|
|
|
|
# The FAST tokens start at position (total_T_images + num_lang_tokens)
|
|
|
|
|
# For next-token prediction:
|
|
|
|
|
# - Position (fast_start - 1) in input predicts fast_action_tokens[0]
|
|
|
|
|
# - Position (fast_start) in input predicts fast_action_tokens[1], etc.
|
|
|
|
|
|
|
|
|
|
# Targets are the FAST action tokens
|
|
|
|
|
fast_targets = fast_action_tokens # (B, num_fast_embs)
|
|
|
|
|
|
|
|
|
|
@@ -1169,24 +727,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|
|
|
|
fast_targets = fast_targets[:, 1:] # shift targets right
|
|
|
|
|
fast_action_masks = fast_action_masks[:, 1:] # shift masks to match targets
|
|
|
|
|
|
|
|
|
|
# from transformers import AutoTokenizer
|
|
|
|
|
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
# "google/paligemma-3b-pt-224",
|
|
|
|
|
# trust_remote_code=True,
|
|
|
|
|
# add_eos_token=True,
|
|
|
|
|
# add_bos_token=False
|
|
|
|
|
# )
|
|
|
|
|
# # remove
|
|
|
|
|
# decoded_tokens = [
|
|
|
|
|
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
|
|
|
|
# for seq in fast_targets
|
|
|
|
|
# ]
|
|
|
|
|
# corrected_tokens = [
|
|
|
|
|
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
|
|
|
|
# for seq in fast_logits_for_pred.argmax(dim=-1)
|
|
|
|
|
# ]
|
|
|
|
|
# breakpoint()
|
|
|
|
|
|
|
|
|
|
# compute cross-entropy loss
|
|
|
|
|
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
|
|
|
|
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
|
|
|
|
|
@@ -1320,38 +860,38 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|
|
|
|
config.validate_features()
|
|
|
|
|
self.config = config
|
|
|
|
|
|
|
|
|
|
# Initialize the core PI0Fast model
|
|
|
|
|
self.init_rtc_processor()
|
|
|
|
|
self.model = PI0FastPytorch(config, rtc_processor=self.rtc_processor)
|
|
|
|
|
|
|
|
|
|
# Enable gradient checkpointing if requested
|
|
|
|
|
if config.gradient_checkpointing:
|
|
|
|
|
self.model.gradient_checkpointing_enable()
|
|
|
|
|
|
|
|
|
|
self.model.to(config.device)
|
|
|
|
|
|
|
|
|
|
# Load FAST tokenizer for action detokenization (only if fast_only mode)
|
|
|
|
|
self.action_tokenizer = None
|
|
|
|
|
self._paligemma_tokenizer = None
|
|
|
|
|
self._fast_skip_tokens = 128
|
|
|
|
|
|
|
|
|
|
# Load tokenizers first
|
|
|
|
|
try:
|
|
|
|
|
from transformers import AutoProcessor, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
# Load FAST tokenizer
|
|
|
|
|
self.action_tokenizer = AutoProcessor.from_pretrained(
|
|
|
|
|
"jadechoghari/fast-libero-tokenizer-mean-std", trust_remote_code=True
|
|
|
|
|
config.action_tokenizer_name, trust_remote_code=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Load PaliGemma tokenizer for token conversion
|
|
|
|
|
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
"google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False
|
|
|
|
|
config.text_tokenizer_name, trust_remote_code=True, add_eos_token=True, add_bos_token=False
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logging.info("Loaded FAST tokenizer for action detokenization")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
|
|
|
|
|
logging.warning("Action tokens will be returned without detokenization")
|
|
|
|
|
self._paligemma_tokenizer = None
|
|
|
|
|
self.action_tokenizer = None
|
|
|
|
|
|
|
|
|
|
# Initialize the core PI0Fast model
|
|
|
|
|
self.init_rtc_processor()
|
|
|
|
|
self.model = PI0FastPytorch(
|
|
|
|
|
config, rtc_processor=self.rtc_processor, paligemma_tokenizer=self._paligemma_tokenizer
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Enable gradient checkpointing if requested
|
|
|
|
|
if config.gradient_checkpointing:
|
|
|
|
|
self.model.gradient_checkpointing_enable()
|
|
|
|
|
|
|
|
|
|
self.model.to(config.device)
|
|
|
|
|
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
@@ -1634,7 +1174,7 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|
|
|
|
Returns:
|
|
|
|
|
Action token IDs
|
|
|
|
|
"""
|
|
|
|
|
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
|
|
|
|
|
return self._paligemma_tokenizer.vocab_size - 1 - self.config.fast_skip_tokens - tokens
|
|
|
|
|
|
|
|
|
|
def decode_actions_with_fast(
|
|
|
|
|
self, token_ids: list[int], time_horizon: int, action_dim: int, relaxed_decoding: bool = True
|
|
|
|
|
@@ -1807,9 +1347,9 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|
|
|
|
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
|
|
|
|
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
|
|
|
|
|
|
|
|
|
# Get optional parameters
|
|
|
|
|
temperature = kwargs.get("temperature", 0.0)
|
|
|
|
|
max_decoding_steps = 256
|
|
|
|
|
# Get decoding parameters
|
|
|
|
|
temperature = self.config.temperature
|
|
|
|
|
max_decoding_steps = self.config.max_decoding_steps
|
|
|
|
|
|
|
|
|
|
# Sample action tokens autoregressively
|
|
|
|
|
action_tokens = self.model.sample_actions_fast(
|
|
|
|
|
@@ -1820,9 +1360,10 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|
|
|
|
max_decoding_steps=max_decoding_steps,
|
|
|
|
|
temperature=temperature,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Detokenize action tokens to continuous actions
|
|
|
|
|
action_horizon = self.config.n_action_steps
|
|
|
|
|
action_dim = 7
|
|
|
|
|
action_dim = self.config.output_features[ACTION].shape[0]
|
|
|
|
|
|
|
|
|
|
continuous_actions = self.detokenize_actions(
|
|
|
|
|
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
|
|
|
|
@@ -1837,23 +1378,25 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|
|
|
|
images, img_masks = self._preprocess_images(batch)
|
|
|
|
|
|
|
|
|
|
# Get FAST action tokens from batch
|
|
|
|
|
fast_action_tokens = batch.get("action.tokens") # (B, max_action_tokens)
|
|
|
|
|
fast_action_masks = batch.get("action.token_mask") # (B, max_action_tokens)
|
|
|
|
|
fast_action_tokens = batch.get(ACTION_TOKENS) # (B, max_action_tokens)
|
|
|
|
|
fast_action_masks = batch.get(ACTION_TOKEN_MASK) # (B, max_action_tokens)
|
|
|
|
|
|
|
|
|
|
# Use full language tokens (no separation into high_level_task and subtask)
|
|
|
|
|
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
|
|
|
|
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
|
|
|
|
tokens = batch.get(OBS_LANGUAGE_TOKENS)
|
|
|
|
|
masks = batch.get(OBS_LANGUAGE_ATTENTION_MASK)
|
|
|
|
|
|
|
|
|
|
if fast_action_tokens is None or fast_action_masks is None:
|
|
|
|
|
raise ValueError("FAST-only mode requires action.tokens and action.token_mask in the batch")
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"PI0Fast requires {ACTION_TOKENS} and {ACTION_TOKEN_MASK} to be present in the batch"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
loss_dict = self.model.forward(
|
|
|
|
|
images,
|
|
|
|
|
img_masks,
|
|
|
|
|
tokens,
|
|
|
|
|
masks,
|
|
|
|
|
fast_action_tokens=fast_action_tokens,
|
|
|
|
|
fast_action_masks=fast_action_masks,
|
|
|
|
|
fast_action_tokens,
|
|
|
|
|
fast_action_masks,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
loss = loss_dict["loss"]
|
|
|
|
|
|