refactor(model): change tensor data type from bfloat16 to float32

- Updated image and state embeddings to use float32 for improved compatibility.
- Adjusted model parameters and hidden states to ensure consistent data type usage.
This commit is contained in:
AdilZouitine
2025-09-24 14:33:11 +02:00
parent cffd545527
commit b1d72ac29c
2 changed files with 6 additions and 6 deletions

View File

@@ -493,7 +493,7 @@ class PI0FlowMatching(nn.Module):
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
img_emb = img_emb.to(dtype=torch.float32)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
@@ -536,7 +536,7 @@ class PI0FlowMatching(nn.Module):
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
state_emb = state_emb.to(dtype=torch.float32)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype

View File

@@ -202,7 +202,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
self.paligemma = self.paligemma.to(dtype=torch.float32)
params_to_change_dtype = [
"language_model.model.layers",
@@ -212,7 +212,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
@@ -262,7 +262,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
hidden_states = hidden_states.to(dtype=torch.float32)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
@@ -303,7 +303,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
att_output = att_output.to(dtype=torch.float32)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []