mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes * refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/ * refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/ * refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py * refactor(rewards): update imports and delete old reward model locations * test(rewards): add reward model tests and update existing test imports * fix(rewards): restore full Classifier and SARM implementations * test(rewards): restore missing CUDA and mixed precision classifier processor tests * refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train * refactor(lerobot_train.py): add missing sampling weight script * linter + missing files * add testing for sampl weighter * revert some useless changes, improve typing * update docs * add automatic detection of the progress path * remove type exp * improve comment * fix: move rabc.py to rewards/sarm/ and update import paths * refactor(imports): update reward model imports to new module structure * refactor(imports): update reward model imports to reflect new module structure * refactor(imports): conditionally import pandas based on availability * feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig * refactor(policies): remove reward model branches from policy factory and __init__ * refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash * feat(train): route reward model training through rewards/factory instead of policies/factory * refactor(train): streamline reward model training logic * fix(rewards): ensure FileNotFoundError is raised for missing config_file * refactor(train): update __get_path_fields__ to include reward_model for config loading * refactor(classifier): remove redundant input normalization in predict_reward method * fix(train): raise ValueError for non-trainable reward models in train function * refactor(pretrained_rm): add model card template * refactor(tests): reward models * refactor(sarm): update reset method and remove unused action prediction methods * refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function * fix(train): raise ValueError for PEFT usage in reward model training * refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
287
src/lerobot/rewards/classifier/modeling_classifier.py
Normal file
287
src/lerobot/rewards/classifier/modeling_classifier.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Tensor | None = None,
|
||||
hidden_states: Tensor | None = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] or [C*F] if no batch
|
||||
"""
|
||||
|
||||
features = features.last_hidden_state
|
||||
|
||||
original_shape = features.shape
|
||||
if features.dim() == 3:
|
||||
features = features.unsqueeze(0) # Add batch dim
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
# Remove batch dim
|
||||
if len(original_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedRewardModel):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
|
||||
# Extract image keys from input_features
|
||||
self.image_keys = [
|
||||
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
|
||||
]
|
||||
|
||||
if self.is_cnn:
|
||||
self.encoders = nn.ModuleDict()
|
||||
for image_key in self.image_keys:
|
||||
encoder = self._create_single_encoder()
|
||||
self.encoders[image_key] = encoder
|
||||
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _create_single_encoder(self):
|
||||
encoder = nn.Sequential(
|
||||
self.encoder,
|
||||
SpatialLearnedEmbeddings(
|
||||
height=4,
|
||||
width=4,
|
||||
channel=self.feature_dim,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
return encoder
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.config.latent_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoders[image_key](x)
|
||||
return outputs
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
labels = batch[REWARD]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack(
|
||||
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
|
||||
)
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Returns 1.0 for success, 0.0 for failure based on image observations."""
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
output = self.predict(images)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
return (output.probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(output.probabilities, dim=1).float()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
Reference in New Issue
Block a user