#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from collections import deque
from typing import Callable , List
import einops
import numpy as np
import torch
import torch . nn . functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor , nn
from torch . optim . lr_scheduler import LambdaLR
from lerobot . common . policies . normalize import Normalize , Unnormalize
from lerobot . common . policies . utils import get_device_from_parameters , populate_queues
from lerobot . common . policies . vqbet . configuration_vqbet import VQBeTConfig
from lerobot . common . policies . vqbet . vqbet_utils import GPT , ResidualVQ
# ruff: noqa: N806
class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
    """
    VQ-BeT Policy as per "Behavior Generation with Latent Actions"
    """

    name = "vqbet"

    def __init__(
        self,
        config: VQBeTConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__()
        if config is None:
            config = VQBeTConfig()

        self.config = config
        self.normalize_inputs = Normalize(
            config.input_shapes, config.input_normalization_modes, dataset_stats
        )
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        self.vqbet = VQBeTModel(config)

        # Every input key starting with "observation.image" is treated as a camera stream.
        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        self.reset()

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`.

        The queues are populated during rollout of the policy; they hold the `n_obs_steps` latest
        observations and up to `action_chunk_size` pending actions.
        """
        self._queues = {
            "observation.images": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.action_chunk_size),
        }

    def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` with all camera images stacked under "observation.images".

        The copy avoids adding a key to a dict that may be owned by the caller.
        """
        batch = dict(batch)
        batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
        return batch

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        Actions are returned one at a time from an internal action queue. When the queue is empty, the model
        is run on the queued observation history to predict a chunk of `action_chunk_size` actions, which
        are unnormalized and pushed into the queue.
        """
        batch = self._stack_images(self.normalize_inputs(batch))
        # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

        if not self.vqbet.action_head.vqvae_model.discretized.item():
            warnings.warn(
                "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
                stacklevel=1,
            )

        if len(self._queues["action"]) == 0:
            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
            actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]

            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]
            # the action queue holds (action_chunk_size) items of shape (batch_size, action_dim),
            # so transpose before filling the queue
            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation.

        While the Residual VQ has not been discretized yet, this performs one RVQ training step and returns
        its loss and code-usage metrics. Afterwards it returns the GPT / bin head / offset head loss dict.
        """
        batch = self._stack_images(self.normalize_inputs(batch))
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
            # loss: total loss of training RVQ
            # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return {
                "loss": loss,
                "n_different_codes": n_different_codes,
                "n_different_combinations": n_different_combinations,
                "recon_l1_error": recon_l1_error,
            }
        # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
        _, loss_dict = self.vqbet(batch, rollout=False)

        return loss_dict
class SpatialSoftmax(nn.Module):
    """
    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.

    Given 2D feature maps (e.g. from a convnet / ViT), each channel is turned into a probability map with a
    softmax over its H*W locations, and the expected (x, y) position under that map is returned. The result is
    one keypoint per channel, expressed in normalized image coordinates in [-1, 1].

    Example: for feature maps of size (512x10x12) a grid of normalized coordinates (10x12x2) is built:
        -----------------------------------------------------
        | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
        | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
        | ...          | ...            | ... | ...         |
        | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
        -----------------------------------------------------
    The channel-wise softmax over the activations (512x120) is multiplied with the flattened coordinates
    (120x2), giving the expected points of maximal activation (512x2).

    By default there is one keypoint per input channel. Passing `num_kp` first applies a learnable 1x1
    convolution (in_channels, H, W) -> (num_kp, H, W), so that `num_kp` keypoints are produced instead.
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) input feature map shape.
            num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
        """
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        # Optional 1x1 projection from the input channels to `num_kp` keypoint channels.
        self.nets = None if num_kp is None else torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
        self._out_c = self._in_c if num_kp is None else num_kp

        # Build the coordinate grid with numpy rather than torch.linspace: the latter seems to behave slightly
        # differently and causes a small degradation in pc_success of pre-trained models.
        grid_x, grid_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self._in_w),
            np.linspace(-1.0, 1.0, self._in_h),
        )
        # Row-major flatten so row i matches the i-th of the H*W flattened feature locations: (H * W, 2).
        coords = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)

        # A buffer follows the module across devices without being a learnable parameter.
        self.register_buffer("pos_grid", torch.from_numpy(coords).float())

    def forward(self, features: Tensor) -> Tensor:
        """
        Args:
            features: (B, C, H, W) input feature maps.
        Returns:
            (B, K, 2) image-space coordinates of keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # Treat every (batch, keypoint) pair as an independent row: (B * K, H * W).
        flat_maps = features.reshape(-1, self._in_h * self._in_w)
        weights = F.softmax(flat_maps, dim=-1)

        # Expected coordinate under each spatial distribution: (B * K, H * W) @ (H * W, 2) -> (B * K, 2).
        keypoints = weights @ self.pos_grid

        return keypoints.view(-1, self._out_c, 2)
class VQBeTModel(nn.Module):
    """VQ-BeT: The underlying neural network for VQ-BeT

    Note: In this code we use the terms `rgb_encoder`, `policy`, `action_head`. The meanings are as follows.
        - The `rgb_encoder` processes rgb-style image observations into one-dimensional embedding vectors.
        - A `policy` is a minGPT architecture that takes observation sequences and action query tokens to
          generate `features`.
        - These `features` pass through the action head, which passes through the code prediction and offset
          prediction heads, and finally generates a prediction for the action chunks.

        -------------------------------** legend **-------------------------------
        │   n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size      │
        │   o_{t} : visual observation at timestep {t}                           │
        │   s_{t} : state observation at timestep {t}                            │
        │   a_{t} : action at timestep {t}                                       │
        │   A_Q : action_query_token                                             │
        --------------------------------------------------------------------------

        Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)

        ┌─────────────────┐            ┌─────────────────┐            ┌─────────────────┐
        │                 │            │                 │            │                 │
        │   RVQ encoder   │    ─►      │     Residual    │    ─►      │   RVQ Decoder   │
        │ (a_{t}~a_{t+p}) │            │  Code Quantizer │            │                 │
        │                 │            │                 │            │                 │
        └─────────────────┘            └─────────────────┘            └─────────────────┘

        Training Phase 2.

          timestep {t-n+1}   timestep {t-n+2}                timestep {t}
            ┌─────┴─────┐     ┌─────┴─────┐                 ┌─────┴─────┐

        o_{t-n+1}         o_{t-n+2}           ...          o_{t}
            │                  │                             │
            │ s_{t-n+1}        │ s_{t-n+2}         ...       │   s_{t}           p
            │     │            │     │                       │     │     ┌───────┴───────┐
            │     │    A_Q     │     │    A_Q      ...       │     │    A_Q     ...     A_Q
            │     │     │      │     │     │                 │     │     │               │
        ┌───▼─────▼─────▼──────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐
        │                                                                                   │
        │                                       GPT                                         │    =>    policy
        │                                                                                   │
        └───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘
                        │                 │                             │               │
                    ┌───┴───┐         ┌───┴───┐                     ┌───┴───┐       ┌───┴───┐
                  code    offset    code    offset                code    offset  code    offset
                    ▼       │         ▼       │                     ▼       │       ▼       │    =>    action_head
               RVQ Decoder  │    RVQ Decoder  │                RVQ Decoder  │  RVQ Decoder  │
                    └── + ──┘         └── + ──┘                     └── + ──┘       └── + ──┘
                        ▼                 ▼                             ▼               ▼
                   action chunk      action chunk                  action chunk     action chunk
                    a_{t-n+1} ~       a_{t-n+2} ~                   a_{t} ~     ...  a_{t+p-1} ~
                    a_{t-n+c}         a_{t-n+c+1}                   a_{t+c-1}        a_{t+p+c-1}

                                                                        ▼
                                                      ONLY this chunk is used in rollout!
    """

    def __init__(self, config: VQBeTConfig):
        super().__init__()
        self.config = config

        self.rgb_encoder = VQBeTRgbEncoder(config)
        self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
        # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
        self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))

        # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
        # The state projector is applied to `observation.state`, so its input size must be the state dimension
        # (not the action dimension, which only happens to coincide for some environments).
        self.state_projector = MLP(
            config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
        )
        self.rgb_feature_projector = MLP(
            self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
        )

        # GPT part of VQ-BeT
        self.policy = GPT(config)
        # bin prediction head / offset prediction head part of VQ-BeT
        self.action_head = VQBeTHead(config)

        # Action tokens for: each observation step, the current action token, and all future action tokens.
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        # Row i holds the indices [i, i + action_chunk_size) of the target action chunk for action token i.
        self.register_buffer(
            "select_target_actions_indices",
            torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
        )

    def forward(
        self, batch: dict[str, Tensor], rollout: bool
    ) -> Tensor | tuple[dict[str, Tensor], dict[str, Tensor]]:
        """Predict action chunks, and compute the loss when not in rollout mode.

        Args:
            batch: must contain "observation.state" (batch, n_obs_steps, ...) and "observation.images"
                (batch, n_obs_steps, n_cameras, ...); "action" is also required when `rollout` is False.
            rollout: if True, only the action chunk predicted at the current timestep is returned.
        Returns:
            If `rollout`: a (batch, action_chunk_size, action_dim) tensor.
            Otherwise: a tuple (action head output dict, loss dict).
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(
            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
        )
        # Separate batch and sequence dims.
        img_features = einops.rearrange(
            img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
        )

        # Arrange prior and current observation step tokens as shown in the class docstring.
        # First project features to token dimension.
        rgb_tokens = self.rgb_feature_projector(
            img_features
        )  # (batch, obs_step, number of different cameras, projection dims)
        input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging.
        input_tokens = torch.stack(input_tokens, dim=2)
        # Tokens per observation step: one per camera, one for the state, and the action query token (last).
        tokens_per_step = input_tokens.shape[2]
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)

        # get action features (pass through GPT)
        features = self.policy(input_tokens)
        # The action query token is the last token of every observation step; collect its sequence positions.
        historical_act_pred_index = np.arange(0, n_obs_steps) * tokens_per_step + (tokens_per_step - 1)

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
        # Thus, it predicts the historical action sequence, in addition to current and future actions (predicting future actions: optional).
        if len_additional_action_token > 0:
            features = torch.cat(
                [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
            )
        else:
            # Guard: `features[:, -0:]` would select the whole sequence rather than nothing.
            features = features[:, historical_act_pred_index]

        # pass through action head
        action_head_output = self.action_head(features)
        # if rollout, VQ-BeT doesn't calculate loss
        if rollout:
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # else, it calculates the overall loss (bin prediction loss, and offset loss)
        target_actions = batch["action"][:, self.select_target_actions_indices]
        loss = self.action_head.loss_fn(action_head_output, target_actions, reduction="mean")
        return action_head_output, loss
class VQBeTHead(nn.Module):
    """Action head of VQ-BeT: predicts RVQ codes (bins) and per-code offsets from GPT features."""

    def __init__(self, config: VQBeTConfig):
        """
        VQBeTHead takes the output of the GPT layers and passes the features through the bin prediction head
        (`self.map_to_cbet_preds_bin`) and the offset prediction head (`self.map_to_cbet_preds_offset`).

        self.map_to_cbet_preds_bin: outputs the logits of each code (for each layer).
            Its input dimension is the GPT output dimension, and its output dimension is
            `self.vqvae_model.vqvae_num_layers (= 2) * self.config.vqvae_n_embed`.
            If codes are selected sequentially (`config.sequentially_select`), `self.map_to_cbet_preds_primary_bin`
            and `self.map_to_cbet_preds_secondary_bin` are used instead of `self.map_to_cbet_preds_bin`.

        self.map_to_cbet_preds_offset: outputs the predicted offsets for all the codes in all the layers.
            Its input dimension is the GPT output dimension, and its output dimension is
            `self.vqvae_model.vqvae_num_layers (= 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
        """

        super().__init__()
        self.config = config
        # init vqvae
        self.vqvae_model = VqVae(config)
        if config.sequentially_select:
            self.map_to_cbet_preds_primary_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.config.vqvae_n_embed],
            )
            # The secondary head is conditioned on the one-hot encoding of the sampled primary code.
            self.map_to_cbet_preds_secondary_bin = MLP(
                in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
                hidden_channels=[self.config.vqvae_n_embed],
            )
        else:
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
            )
        self.map_to_cbet_preds_offset = MLP(
            in_channels=config.gpt_output_dim,
            hidden_channels=[
                self.vqvae_model.vqvae_num_layers
                * self.config.vqvae_n_embed
                * config.action_chunk_size
                * config.output_shapes["action"][0],
            ],
        )
        # loss
        self._focal_loss_fn = FocalLoss(gamma=2.0)

    def discretize(self, n_vqvae_training_steps, actions):
        """Run one Residual VQ training step on ground-truth action sequences.

        Once the RVQ has been optimized for `n_vqvae_training_steps` steps, it is marked as discretized and
        frozen.

        Args:
            n_vqvae_training_steps: number of RVQ steps after which the RVQ is frozen.
            actions: action sequences; dim 1 is time and must be at least `action_chunk_size` long.
        Returns:
            (loss, n_different_codes, n_different_combinations, recon_l1_error)
        Raises:
            ValueError: if the action sequence is shorter than `action_chunk_size`.
        """
        chunk_size = self.config.action_chunk_size
        n_windows = actions.shape[1] + 1 - chunk_size
        if n_windows <= 0:
            raise ValueError(
                f"Cannot discretize actions: the action sequence length ({actions.shape[1]}) is shorter than "
                f"`action_chunk_size` ({chunk_size})."
            )
        # Resize the action sequence data to fit the action chunk size using a sliding window approach.
        actions = torch.cat([actions[:, j : j + chunk_size, :] for j in range(n_windows)], dim=0)
        # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.

        loss, metric = self.vqvae_model.vqvae_forward(actions)
        n_different_codes = sum(
            [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
        )
        n_different_combinations = len(torch.unique(metric[2], dim=0))
        recon_l1_error = metric[0].detach().cpu().item()
        self.vqvae_model.optimized_steps += 1
        # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
        if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
            self.vqvae_model.discretized = torch.tensor(True)
            self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
            print("Finished discretizing action data!")
            self.vqvae_model.eval()
            for param in self.vqvae_model.vq_layer.parameters():
                param.requires_grad = False
        return loss, n_different_codes, n_different_combinations, recon_l1_error

    def forward(self, x, **kwargs):
        """Sample RVQ codes and offsets from GPT features and decode them into action chunks.

        Args:
            x: (N, T, gpt_output_dim) features, one per action query token.
        Returns:
            dict with "cbet_logits" (NT, G, n_embed), "predicted_action" (N, T, action_chunk_size * action_dim),
            "sampled_centers" (NT, G) and "decoded_action" (NT, action_chunk_size, action_dim).
        """
        # N is the batch size, and T is number of action query tokens, which are processed through the same GPT
        N, T, _ = x.shape
        # we process the N and T dimensions in parallel. Thus, the dimensions would be
        # (batch size * number of action query tokens, action chunk size, action dimension)
        x = einops.rearrange(x, "N T WA -> (N T) WA")

        # sample offsets
        cbet_offsets = self.map_to_cbet_preds_offset(x)
        cbet_offsets = einops.rearrange(
            cbet_offsets,
            "(NT) (G C WA) -> (NT) G C WA",
            G=self.vqvae_model.vqvae_num_layers,
            C=self.config.vqvae_n_embed,
        )
        # if self.config.sequentially_select is True, bin prediction head first samples the primary code, and then samples the secondary code
        if self.config.sequentially_select:
            cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)

            # select primary bin first
            cbet_primary_probs = torch.softmax(
                cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            NT, choices = cbet_primary_probs.shape
            sampled_primary_centers = einops.rearrange(
                torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )

            cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
                torch.cat(
                    (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
                    dim=1,
                )
            )
            cbet_secondary_probs = torch.softmax(
                cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
            )
            sampled_secondary_centers = einops.rearrange(
                torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), dim=1)
            cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
        # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
            )
            cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            )

        device = get_device_from_parameters(self)
        indices = (
            torch.arange(NT, device=device).unsqueeze(1),
            torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0),
            sampled_centers,
        )
        # Use advanced indexing to extract only the offsets corresponding to the sampled codes: (NT, G, WA).
        sampled_offsets = cbet_offsets[indices]
        # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
        sampled_offsets = sampled_offsets.sum(dim=1)
        with torch.no_grad():
            # Get the centroids (= vectors corresponding to the codes) of each layer to pass them through the RVQ decoder
            return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
            # pass the centroids through the decoder to get actions.
            decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
        # reshape the extracted offsets to match the decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
        )
        # add offset and decoded centroids (gradients flow only through the offsets)
        predicted_action = decoded_action + sampled_offsets
        predicted_action = einops.rearrange(
            predicted_action,
            "(N T) W A -> N T (W A)",
            N=N,
            T=T,
            W=self.config.action_chunk_size,
        )

        return {
            "cbet_logits": cbet_logits,
            "predicted_action": predicted_action,
            "sampled_centers": sampled_centers,
            "decoded_action": decoded_action,
        }

    def loss_fn(self, pred, target, **kwargs):
        """
        For given ground truth action values (target) and prediction (pred), this function calculates the overall loss.

        predicted_action: predicted action chunk (offset + decoded centroids)
        sampled_centers: sampled centroids (code of RVQ)
        decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
        NT: batch size * T
        T: number of action query tokens, which are processed through the same GPT
        cbet_logits: logits of all codes in each layer

        Returns a dict whose "loss" entry is the differentiable total loss; all other entries are Python floats.
        """
        action_seq = target
        predicted_action = pred["predicted_action"]
        sampled_centers = pred["sampled_centers"]
        decoded_action = pred["decoded_action"]
        NT = predicted_action.shape[0] * predicted_action.shape[1]

        cbet_logits = pred["cbet_logits"]

        predicted_action = einops.rearrange(
            predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
        )

        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            _, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G

        # Now we can compute the loss.

        # offset loss is L1 distance between the predicted action and ground truth action
        offset_loss = F.l1_loss(action_seq, predicted_action)

        # calculate primary code prediction loss
        cbet_loss1 = self._focal_loss_fn(
            cbet_logits[:, 0, :],
            action_bins[:, 0],
        )
        # calculate secondary code prediction loss
        cbet_loss2 = self._focal_loss_fn(
            cbet_logits[:, 1, :],
            action_bins[:, 1],
        )
        # add all the prediction loss
        cbet_loss = (
            cbet_loss1 * self.config.primary_code_loss_weight
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

        equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
        equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
        offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
        action_error_max = torch.max(torch.abs(action_seq - predicted_action))

        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

        loss_dict = {
            "loss": loss,
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
            "action_mse_error": action_mse_error.detach().cpu().item(),
        }
        return loss_dict
class VQBeTOptimizer(torch.optim.Adam):
    """Adam optimizer for VQ-BeT with three parameter groups, in this order:

    1. Weight-decayed parameters: the GPT parameters that `configure_parameters` puts in its decay set, the RGB
       encoder, the state / RGB feature projectors, the action query token, the offset head and the bin head(s).
       Uses `cfg.training.adam_weight_decay` and `cfg.training.lr`.
    2. Residual VQ-VAE parameters (encoder, decoder, vq layer). Uses weight decay 1e-4 and `cfg.training.vqvae_lr`.
    3. The GPT parameters that `configure_parameters` puts in its no-decay set. Uses no weight decay and
       `cfg.training.lr`.
    """

    def __init__(self, policy, cfg):
        model = policy.vqbet
        head = model.action_head
        vqvae = head.vqvae_model
        train_cfg = cfg.training

        vqvae_params = [
            *vqvae.encoder.parameters(),
            *vqvae.decoder.parameters(),
            *vqvae.vq_layer.parameters(),
        ]

        gpt_decay_params, no_decay_params = model.policy.configure_parameters()

        # Which bin head(s) exist depends on whether codes are sampled sequentially.
        if cfg.policy.sequentially_select:
            bin_head_params = [
                *head.map_to_cbet_preds_primary_bin.parameters(),
                *head.map_to_cbet_preds_secondary_bin.parameters(),
            ]
        else:
            bin_head_params = [*head.map_to_cbet_preds_bin.parameters()]

        decay_params = [
            *gpt_decay_params,
            *model.rgb_encoder.parameters(),
            *model.state_projector.parameters(),
            *model.rgb_feature_projector.parameters(),
            model.action_token,
            *head.map_to_cbet_preds_offset.parameters(),
            *bin_head_params,
        ]

        decay_group = {"params": decay_params, "weight_decay": train_cfg.adam_weight_decay, "lr": train_cfg.lr}
        vqvae_group = {"params": vqvae_params, "weight_decay": 0.0001, "lr": train_cfg.vqvae_lr}
        no_decay_group = {"params": no_decay_params, "weight_decay": 0.0, "lr": train_cfg.lr}

        super().__init__(
            [decay_group, vqvae_group, no_decay_group],
            lr=train_cfg.lr,
            betas=train_cfg.adam_betas,
            eps=train_cfg.adam_eps,
        )
class VQBeTScheduler(nn.Module):
    """Learning-rate schedule for VQ-BeT's two training phases.

    The LR multiplier is held at 1 for the first `cfg.training.n_vqvae_training_steps` steps (Residual VQ
    training). After that, the step counter restarts at 0 and follows a linear warmup over
    `cfg.training.lr_warmup_steps` steps. Then a half-cosine decay is applied, whose progress is normalized
    by `cfg.training.offline_steps - cfg.training.lr_warmup_steps`.
    """

    def __init__(self, optimizer, cfg):
        super().__init__()
        vqvae_steps = cfg.training.n_vqvae_training_steps
        warmup_steps = cfg.training.lr_warmup_steps
        total_steps = cfg.training.offline_steps
        num_cycles = 0.5

        def lr_lambda(current_step):
            # Phase 1: constant multiplier while the Residual VQ is being trained.
            if current_step < vqvae_steps:
                return float(1)

            # Phase 2: count steps from the start of GPT / head training.
            phase_step = current_step - vqvae_steps
            if phase_step < warmup_steps:
                return float(phase_step) / float(max(1, warmup_steps))

            progress = float(phase_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
            return max(0.0, cosine)

        self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)

    def step(self):
        """Advance the wrapped LambdaLR scheduler by one step."""
        self.lr_scheduler.step()
class VQBeTRgbEncoder(nn.Module):
    """Encode an RGB image into a 1D feature vector.

    The pipeline is: optional crop -> torchvision backbone (truncated) -> SpatialSoftmax keypoints ->
    Linear + ReLU.

    Same with DiffusionRgbEncoder from modeling_diffusion.py
    """

    def __init__(self, config: VQBeTConfig):
        super().__init__()

        # Optional cropping. Eval always uses a center crop; training uses a random crop when
        # `config.crop_is_random`, else the same center crop.
        self.do_crop = config.crop_shape is not None
        if self.do_crop:
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            self.maybe_random_crop = (
                torchvision.transforms.RandomCrop(config.crop_shape)
                if config.crop_is_random
                else self.center_crop
            )

        # Backbone: a torchvision model with its last two children removed.
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        backbone_model = getattr(torchvision.models, config.vision_backbone)(
            weights=config.pretrained_backbone_weights
        )
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))

        if config.use_group_norm and config.pretrained_backbone_weights:
            raise ValueError(
                "You can't replace BatchNorm in a pretrained model without ruining the weights!"
            )
        if config.use_group_norm:
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Dry-run the backbone to find the feature map shape for the pooling layer. The dummy input takes its
        # channel count from `config.input_shapes`, and its height/width from `config.crop_shape` when set,
        # otherwise from `config.input_shapes`.
        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        assert len(image_keys) == 1
        image_shape = config.input_shapes[image_keys[0]]
        spatial_size = config.crop_shape if config.crop_shape is not None else image_shape[1:]
        dummy_input = torch.zeros(size=(1, image_shape[0], *spatial_size))
        with torch.inference_mode():
            feature_map_shape = tuple(self.backbone(dummy_input).shape[1:])

        # Pooling to keypoints, then a same-sized linear layer with a ReLU.
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        if self.do_crop:
            crop = self.maybe_random_crop if self.training else self.center_crop
            x = crop(x)

        feature_map = self.backbone(x)
        keypoints = torch.flatten(self.pool(feature_map), start_dim=1)
        return self.relu(self.out(keypoints))
def _replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """Swap every submodule matching `predicate` for `func(submodule)`.

    Args:
        root_module: Module whose submodules are replaced in place.
        predicate: Returns True for a module that should be replaced.
        func: Builds the replacement module from the module being replaced.
    Returns:
        `func(root_module)` when the root itself matches; otherwise `root_module`, mutated in place.
    """
    if predicate(root_module):
        return func(root_module)

    # Dotted paths (split into parts) of every matching submodule, collected before any mutation.
    matching_paths = [
        name.split(".")
        for name, module in root_module.named_modules(remove_duplicate=True)
        if predicate(module)
    ]
    for *parent_path, leaf in matching_paths:
        owner = root_module.get_submodule(".".join(parent_path)) if parent_path else root_module
        # nn.Sequential children are addressed by integer index; others by attribute name.
        indexed = isinstance(owner, nn.Sequential)
        original = owner[int(leaf)] if indexed else getattr(owner, leaf)
        replacement = func(original)
        if indexed:
            owner[int(leaf)] = replacement
        else:
            setattr(owner, leaf, replacement)

    # Sanity check: nothing matching the predicate should remain.
    assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
    return root_module
class VqVae(nn.Module):
    """Action VQ-VAE: MLP encoder -> residual VQ -> MLP decoder over flattened action chunks."""

    def __init__(
        self,
        config: VQBeTConfig,
    ):
        """
        VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
        Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
        The vq_layer uses residual VQs.

        This class contains functions for training the encoder and decoder along with the residual VQ layer
        (for training phase 1), as well as functions to help BeT training part in training phase 2.
        """

        super().__init__()
        self.config = config
        # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
        self.register_buffer("discretized", torch.tensor(False))
        self.optimized_steps = 0
        # we use the fixed number of layers for Residual VQ across all environments.
        self.vqvae_num_layers = 2

        self.vq_layer = ResidualVQ(
            dim=config.vqvae_embedding_dim,
            num_quantizers=self.vqvae_num_layers,
            codebook_size=config.vqvae_n_embed,
        )

        # Encoder/decoder operate on a whole action chunk flattened to (action_dim * action_chunk_size).
        self.encoder = MLP(
            in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                config.vqvae_embedding_dim,
            ],
        )
        self.decoder = MLP(
            in_channels=config.vqvae_embedding_dim,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                self.config.output_shapes["action"][0] * self.config.action_chunk_size,
            ],
        )

    def get_embeddings_from_code(self, encoding_indices):
        # This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices.
        with torch.no_grad():
            z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
            # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
            z_embed = z_embed.sum(dim=0)
        return z_embed

    def get_action_from_latent(self, latent):
        """Decode latent vector(s) into actions of shape (N, T, A), A being the action dimension."""
        output = self.decoder(latent)
        # The decoder emits a flat (N, T*A) vector; split it back into T steps of A-dim actions.
        # (This reshape is identical whether or not action_chunk_size == 1.)
        return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])

    def _encode_and_quantize(self, flat_state):
        """Encode flattened (N, T*A) action chunks and pass them through the residual VQ layer.

        Returns:
            (state_vq, vq_code, vq_loss_state): the quantized latents and codes reshaped back to
            the encoder's leading (batch) dimension, plus the loss exactly as returned by vq_layer.
        """
        state_rep = self.encoder(flat_state)
        state_rep_shape = state_rep.shape[:-1]
        # state_rep is 2-D here, so this adds a singleton middle axis: (N, 1, D).
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        return state_vq, vq_code, vq_loss_state

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
        # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
        state = einops.rearrange(state, "N T A -> N (T A)")
        with torch.no_grad():
            # The VQ loss is not needed for labelling, so it is discarded.
            state_vq, vq_code, _ = self._encode_and_quantize(state)
        return state_vq, vq_code

    def vqvae_forward(self, state):
        # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181).
        state = einops.rearrange(state, "N T A -> N (T A)")
        # We start with passing action (or action chunk) at:t+n through the encoder ϕ.
        # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
        state_vq, vq_code, vq_loss_state = self._encode_and_quantize(state)
        # Reduce the VQ loss returned by vq_layer to a scalar.
        vq_loss_state = torch.sum(vq_loss_state)
        # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
        dec_out = self.decoder(state_vq)
        # Calculate L1 reconstruction loss
        encoder_loss = (state - dec_out).abs().mean()
        # add encoder reconstruction loss and commitment loss (the VQ term has a hard-coded weight of 5)
        rep_loss = encoder_loss + vq_loss_state * 5

        metric = (
            encoder_loss.clone().detach(),
            vq_loss_state.clone().detach(),
            vq_code,
            rep_loss.item(),
        )
        return rep_loss, metric
class FocalLoss(nn.Module):
    """Focal loss over class logits: mean/sum of -(1 - p_t)**gamma * log(p_t).

    With gamma == 0 this reduces to plain cross-entropy.
    From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py

    Args:
        gamma: Focusing exponent applied to (1 - p_t).
        size_average: If True, return the mean over all elements; otherwise the sum.
    """

    def __init__(self, gamma: float = 0, size_average: bool = True):
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, input, target):
        """
        Args:
            input: Logits of shape (N, C) or (N, T, C).
            target: Integer class indices with N (resp. N*T) elements.
        Returns:
            Scalar loss tensor.
        Raises:
            ValueError: If `input` is neither 2-D nor 3-D.
        """
        if input.dim() not in (2, 3):
            # Previously this fell through and failed with an UnboundLocalError on `logpt`.
            raise ValueError(f"FocalLoss expects 2-D or 3-D logits, got shape {tuple(input.shape)}.")
        logpt = F.log_softmax(input, dim=-1)
        if input.dim() == 3:
            N, T, _ = input.shape
            # reshape (not view) so that non-contiguous targets are accepted.
            logpt = logpt.gather(-1, target.reshape(N, T, 1)).reshape(N, T)
        else:
            logpt = logpt.gather(-1, target.reshape(-1, 1)).reshape(-1)
        pt = logpt.exp()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
class MLP(torch.nn.Sequential):
    """Feed-forward stack of Linear layers with a ReLU between consecutive layers.

    Args:
        in_channels: Input feature size of the first Linear layer.
        hidden_channels: Output size of each Linear layer, in order. The last entry is the
            network's output size and is not followed by an activation.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
    ):
        widths = [in_channels, *hidden_channels]
        modules = []
        # Every layer except the last is Linear followed by ReLU.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            modules += [torch.nn.Linear(fan_in, fan_out), torch.nn.ReLU()]
        # Final projection without activation.
        modules.append(torch.nn.Linear(widths[-2], widths[-1]))
        super().__init__(*modules)