diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 66bd81e61..8b43334e4 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -493,7 +493,7 @@ class PI0FlowMatching(nn.Module): img_mask, ) in zip(images, img_masks, strict=False): img_emb = self.paligemma_with_expert.embed_image(img) - img_emb = img_emb.to(dtype=torch.bfloat16) + img_emb = img_emb.to(dtype=torch.float32) # Normalize image embeddings img_emb_dim = img_emb.shape[-1] @@ -536,7 +536,7 @@ class PI0FlowMatching(nn.Module): # Embed state state_emb = self.state_proj(state) - state_emb = state_emb.to(dtype=torch.bfloat16) + state_emb = state_emb.to(dtype=torch.float32) embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] dtype = state_emb.dtype diff --git a/src/lerobot/policies/pi0/paligemma_with_expert.py b/src/lerobot/policies/pi0/paligemma_with_expert.py index edc34b7c5..f652561af 100644 --- a/src/lerobot/policies/pi0/paligemma_with_expert.py +++ b/src/lerobot/policies/pi0/paligemma_with_expert.py @@ -202,7 +202,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): self.paligemma.eval() def to_bfloat16_like_physical_intelligence(self): - self.paligemma = self.paligemma.to(dtype=torch.bfloat16) + self.paligemma = self.paligemma.to(dtype=torch.float32) params_to_change_dtype = [ "language_model.model.layers", @@ -212,7 +212,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): ] for name, param in self.named_parameters(): if any(selector in name for selector in params_to_change_dtype): - param.data = param.data.to(dtype=torch.bfloat16) + param.data = param.data.to(dtype=torch.float32) def embed_image(self, image: torch.Tensor): # Handle different transformers versions @@ -262,7 +262,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - hidden_states = hidden_states.to(dtype=torch.bfloat16) + hidden_states = hidden_states.to(dtype=torch.float32) query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) @@ -303,7 +303,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): att_output = attention_interface( attention_mask, batch_size, head_dim, query_states, key_states, value_states ) - att_output = att_output.to(dtype=torch.bfloat16) + att_output = att_output.to(dtype=torch.float32) # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) outputs_embeds = []