mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
refactor(model): change tensor data type from bfloat16 to float32
- Updated image and state embeddings to use float32 for improved compatibility. - Adjusted model parameters and hidden states to ensure consistent data type usage.
This commit is contained in:
@@ -493,7 +493,7 @@ class PI0FlowMatching(nn.Module):
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
img_emb = img_emb.to(dtype=torch.float32)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
@@ -536,7 +536,7 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
state_emb = state_emb.to(dtype=torch.float32)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
|
||||
@@ -202,7 +202,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
self.paligemma = self.paligemma.to(dtype=torch.float32)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
@@ -212,7 +212,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Handle different transformers versions
|
||||
@@ -262,7 +262,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
@@ -303,7 +303,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
att_output = att_output.to(dtype=torch.float32)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
|
||||
Reference in New Issue
Block a user