From fc6c94c82a4624bdfeffffc7a30dd00c67b2065c Mon Sep 17 00:00:00 2001 From: masato-ka Date: Thu, 23 Apr 2026 23:26:58 +0900 Subject: [PATCH] =?UTF-8?q?fix(sarm):=20handle=20BaseModelOutputWithPoolin?= =?UTF-8?q?g=20from=20transformers=205.x=20in=E2=80=A6=20(#3419)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in CLIP encoding In transformers 5.x, CLIPModel.get_image_features() and get_text_features() return BaseModelOutputWithPooling instead of a plain torch.FloatTensor. Added an isinstance check to extract pooler_output when the return value is not a tensor, maintaining backward compatibility with transformers 4.x. Fixes AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach' * Adding an assertion check for pooler_output of CLIP. This change is in response to the comment below. https://github.com/huggingface/lerobot/pull/3419#discussion_r3112594387 * Adding an assertion check for pooler_output of CLIP. This change is in response to the comment below. 
Change to a simple check and raise https://github.com/huggingface/lerobot/pull/3419#discussion_r3126953776 --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/policies/sarm/processor_sarm.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index e939b3485..b60271b49 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -455,7 +455,13 @@ class SARMEncodingProcessorStep(ProcessorStep): inputs = {k: v.to(self.device) for k, v in inputs.items()} # Get image embeddings - embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() + # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor + output = self.clip_model.get_image_features(**inputs) + if not isinstance(output, torch.Tensor): + output = output.pooler_output + if output is None: + raise ValueError("pooler_output should not be None for CLIP models.") + embeddings = output.detach().cpu() # Handle single frame case if embeddings.dim() == 1: @@ -482,7 +488,13 @@ class SARMEncodingProcessorStep(ProcessorStep): inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(self.device) for k, v in inputs.items()} - text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() + # transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor + output = self.clip_model.get_text_features(**inputs) + if not isinstance(output, torch.Tensor): + output = output.pooler_output + if output is None: + raise ValueError("pooler_output should not be None for CLIP models.") + text_embedding = output.detach().cpu() text_embedding = text_embedding.expand(batch_size, -1) return text_embedding