# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
from torch import Tensor, nn

from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.constants import OBS_IMAGE, REWARD


class ClassifierOutput:
    """Wrapper for classifier outputs with additional metadata."""

    def __init__(
        self,
        logits: Tensor,
        probabilities: Tensor | None = None,
        hidden_states: Tensor | None = None,
    ) -> None:
        # Raw, un-normalised scores produced by the classifier head.
        self.logits = logits
        # Sigmoid / softmax of the logits when the producer computed them.
        self.probabilities = probabilities
        # Concatenated per-camera encoder features fed to the classifier head.
        self.hidden_states = hidden_states

    def __repr__(self) -> str:
        return (
            f"ClassifierOutput(logits={self.logits}, "
            f"probabilities={self.probabilities}, "
            f"hidden_states={self.hidden_states})"
        )


class SpatialLearnedEmbeddings(nn.Module):
    """Learned spatial pooling of a feature map.

    Each feature map is multiplied element-wise by ``num_features`` learned
    spatial kernels and summed over the two spatial dimensions, which turns a
    ``[B, C, H, W]`` map into a flat ``[B, C * num_features]`` embedding.
    """

    def __init__(self, height, width, channel, num_features=8):
        """
        PyTorch implementation of learned spatial embeddings

        Args:
            height: Spatial height of input features
            width: Spatial width of input features
            channel: Number of input channels
            num_features: Number of output embedding dimensions
        """
        super().__init__()
        self.height = height
        self.width = width
        self.channel = channel
        self.num_features = num_features

        # Kernel layout is [C, H, W, F]; inputs must broadcast against it.
        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")

    def forward(self, features):
        """
        Forward pass for spatial embedding

        Args:
            features: Either a tensor of shape [B, C, H, W] (or [C, H, W] if
                no batch), or an encoder output object exposing such a tensor
                as ``last_hidden_state``.
        Returns:
            Output tensor of shape [B, C*F] or [C*F] if no batch
        """
        # Encoders wrapped in this file may return either a model-output object
        # or a bare tensor (e.g. an ``nn.Sequential`` backbone); accept both.
        features = getattr(features, "last_hidden_state", features)

        is_unbatched = features.dim() == 3
        if is_unbatched:
            features = features.unsqueeze(0)  # Add batch dim

        features_expanded = features.unsqueeze(-1)  # [B, C, H, W, 1]
        kernel_expanded = self.kernel.unsqueeze(0)  # [1, C, H, W, F]

        # Element-wise multiplication, then reduce the spatial dims (H, W).
        output = (features_expanded * kernel_expanded).sum(dim=(2, 3))  # [B, C, F]

        # Merge the channel and feature dimensions.
        output = output.view(output.size(0), -1)  # [B, C*F]

        if is_unbatched:
            output = output.squeeze(0)  # Remove the batch dim we added

        return output


class Classifier(PreTrainedRewardModel):
    """Image classifier built on top of a pre-trained encoder."""

    name = "reward_classifier"
    config_class = RewardClassifierConfig

    def __init__(
        self,
        config: RewardClassifierConfig,
    ):
        """Load the pre-trained encoder named by ``config`` and build the heads.

        Args:
            config: Classifier configuration. This constructor reads
                ``model_name``, ``model_type`` and ``input_features``; the
                helper methods read further fields (dimensions, dropout,
                number of cameras / classes).
        """
        # Imported lazily so `transformers` is only needed when a classifier is built.
        from transformers import AutoModel

        super().__init__(config)
        self.config = config

        # Set up encoder.
        # NOTE(security): `trust_remote_code=True` lets the model repository run
        # its own Python code at load time; only point `model_name` at trusted repos.
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)

        # Extract vision model if we're given a multimodal model
        if hasattr(encoder, "vision_model"):
            logging.info("Multimodal model detected - using vision encoder only")
            self.encoder = encoder.vision_model
            self.vision_config = encoder.config.vision_config
        else:
            self.encoder = encoder
            self.vision_config = getattr(encoder, "config", None)

        # Model type from config
        self.is_cnn = self.config.model_type == "cnn"

        # For CNNs, initialize backbone
        if self.is_cnn:
            self._setup_cnn_backbone()

        self._freeze_encoder()

        # `nn.ModuleDict` keys may not contain ".", hence the replacement.
        self.image_keys = [
            key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
        ]

        if self.is_cnn:
            # One projection stack per camera; they all wrap the same frozen backbone.
            self.encoders = nn.ModuleDict()
            for image_key in self.image_keys:
                self.encoders[image_key] = self._create_single_encoder()

        self._build_classifier_head()

    def _setup_cnn_backbone(self):
        """Set up CNN encoder and record its output channel count in ``feature_dim``.

        Raises:
            ValueError: If the encoder has neither an ``fc`` layer nor a
                ``config.hidden_sizes`` attribute.
        """
        if hasattr(self.encoder, "fc"):
            # Backbone ending in a fully-connected layer: drop that last child.
            self.feature_dim = self.encoder.fc.in_features
            self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
        elif hasattr(self.encoder.config, "hidden_sizes"):
            self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
        else:
            raise ValueError("Unsupported CNN architecture")

    def _freeze_encoder(self) -> None:
        """Freeze the encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def _create_single_encoder(self) -> nn.Sequential:
        """Build the per-camera stack: shared backbone -> spatial pooling -> projection."""
        pooled_dim = self.feature_dim * self.config.image_embedding_pooling_dim
        return nn.Sequential(
            self.encoder,
            # NOTE(review): the 4x4 spatial size is hard-coded; presumably it matches
            # the backbone's output grid for the expected input resolution - confirm.
            SpatialLearnedEmbeddings(
                height=4,
                width=4,
                channel=self.feature_dim,
                num_features=self.config.image_embedding_pooling_dim,
            ),
            nn.Dropout(self.config.dropout_rate),
            nn.Linear(pooled_dim, self.config.latent_dim),
            nn.LayerNorm(self.config.latent_dim),
            nn.Tanh(),
        )

    def _build_classifier_head(self) -> None:
        """Initialize the classifier head architecture.

        Raises:
            ValueError: If a non-CNN encoder's config has no ``hidden_size``.
        """
        # Get input dimension based on model type
        if self.is_cnn:
            input_dim = self.config.latent_dim
        elif hasattr(self.encoder.config, "hidden_size"):  # Transformer models
            input_dim = self.encoder.config.hidden_size
        else:
            raise ValueError("Unsupported transformer architecture since hidden_size is not found")

        # Binary classification uses a single logit; otherwise one logit per class.
        output_dim = 1 if self.config.num_classes == 2 else self.config.num_classes

        self.classifier_head = nn.Sequential(
            nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
            nn.Dropout(self.config.dropout_rate),
            nn.LayerNorm(self.config.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.config.hidden_dim, output_dim),
        )

    def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
        """Extract the appropriate output from the encoder.

        Args:
            x: Image tensor for one camera.
            image_key: Sanitised camera key, used to pick the per-camera CNN stack.

        Returns:
            The per-camera feature tensor (projection output for CNNs, first
            token of ``last_hidden_state`` for transformers).
        """
        # NOTE(review): for CNNs this `no_grad` also covers the per-camera
        # SpatialLearnedEmbeddings / Linear / LayerNorm layers, so they never
        # receive gradients even though they are not frozen. Left unchanged to
        # preserve existing training behaviour - confirm whether it is intended.
        with torch.no_grad():
            if self.is_cnn:
                # The HF ResNet applies pooling internally
                return self.encoders[image_key](x)
            # Transformer models
            outputs = self.encoder(x)
            return outputs.last_hidden_state[:, 0, :]

    def _extract_images(self, batch: dict[str, Tensor]) -> list[Tensor]:
        """Collect the image tensors from ``batch`` in ``config.input_features`` order."""
        # Prefix match on OBS_IMAGE; presumably this also covers the plural
        # image-key prefix - confirm against lerobot.utils.constants.
        return [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]

    def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]:
        """Extract image tensors and label tensors from batch."""
        images = self._extract_images(batch)
        labels = batch[REWARD]
        return images, labels

    def predict(self, xs: list) -> ClassifierOutput:
        """Forward pass of the classifier for inference.

        Args:
            xs: One image tensor per camera, ordered like ``self.image_keys``.

        Returns:
            A ``ClassifierOutput`` holding logits, probabilities and the
            concatenated encoder features.
        """
        # `strict=True` raises if the number of images and camera keys differ.
        encoder_outputs = torch.hstack(
            [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
        )
        logits = self.classifier_head(encoder_outputs)

        if self.config.num_classes == 2:
            logits = logits.squeeze(-1)
            probabilities = torch.sigmoid(logits)
        else:
            probabilities = torch.softmax(logits, dim=-1)

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
        """Returns 1.0 for success, 0.0 for failure based on image observations.

        For more than two classes the predicted class index is returned as a
        float tensor instead.
        """
        output = self.predict(self._extract_images(batch))

        if self.config.num_classes == 2:
            return (output.probabilities > 0.5).float()
        return torch.argmax(output.probabilities, dim=1).float()

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
        """Standard forward pass for training compatible with train.py.

        Returns:
            ``(loss, metrics)`` where ``metrics`` holds ``accuracy`` (percent),
            ``correct`` and ``total`` for logging.
        """
        images, labels = self.extract_images_and_labels(batch)
        outputs = self.predict(images)

        if self.config.num_classes == 2:
            # Binary classification. BCE-with-logits needs floating-point targets,
            # so cast the labels to the logits' dtype (no-op when they already match).
            targets = labels.to(dtype=outputs.logits.dtype)
            loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, targets)
            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
        else:
            # Multi-class classification
            loss = nn.functional.cross_entropy(outputs.logits, labels.long())
            predictions = torch.argmax(outputs.logits, dim=1)

        # Calculate accuracy for logging
        correct = (predictions == labels).sum().item()
        total = labels.size(0)
        accuracy = 100 * correct / total

        output_dict = {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
        }

        return loss, output_dict

    def predict_reward(self, batch, threshold=0.5):
        """Eval method. Returns predicted reward with the decision threshold as argument.

        Args:
            batch: Mapping that contains the image observations.
            threshold: Probability above which a binary prediction counts as 1.0.
                Ignored when there are more than two classes.

        Returns:
            A float 0/1 tensor for binary classification, otherwise the
            predicted class indices.
        """
        probs = self.predict(self._extract_images(batch)).probabilities

        if self.config.num_classes == 2:
            # Lazy %-formatting: the tensor is only rendered when DEBUG is enabled.
            logging.debug("Predicted reward images: %s", probs)
            return (probs > threshold).float()
        return torch.argmax(probs, dim=1)