diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 58b5dc07b..450fe25ce 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -59,6 +59,12 @@ class ActionSelectKwargs(TypedDict, total=False): execution_horizon: int | None +def _gated_residual(residual: torch.Tensor, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Gated residual connection: residual + gate * hidden_states. + """ + return residual + gate.unsqueeze(-1) * hidden_states + + def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -217,18 +223,53 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images +class AdaRMSNorm(nn.Module): + """RMSNorm wrapper that supports optional AdaRMS conditioning. + + When called with `cond=None`, behaves like standard RMSNorm and returns a gate of ones. + When called with a conditioning tensor, applies AdaRMS: uses a linear projection to produce + a scale and gate from the conditioning input. + """ + + def __init__(self, base_norm: nn.Module, cond_dim: int | None = None): + super().__init__() + self.base_norm = base_norm + if cond_dim is not None: + hidden_size = base_norm.weight.shape[0] + self.ada_proj = nn.Linear(cond_dim, 2 * hidden_size, bias=False) + nn.init.zeros_(self.ada_proj.weight) + else: + self.ada_proj = None + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None): + normed = self.base_norm(x) + if cond is None or self.ada_proj is None: + gate = torch.ones(x.shape[:-1], dtype=x.dtype, device=x.device) + return normed, gate + scale_gate = self.ada_proj(cond) + scale, gate = scale_gate.chunk(2, dim=-1) + normed = normed * (1 + scale) + return normed, gate + + # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): - models = [paligemma.language_model, gemma_expert.model] + models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + if isinstance(layer.input_layernorm, AdaRMSNorm): + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + else: + hidden_states = layer.input_layernorm(hidden_states) # noqa: PLW2901 + gate = torch.ones( + hidden_states.shape[:-1], dtype=hidden_states.dtype, device=hidden_states.device + ) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -254,10 +295,10 @@ def compute_layer_complete( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, @@ -265,7 +306,7 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] @@ -277,15 +318,19 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 after_first_residual = out_emb.clone() - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + if isinstance(layer.post_attention_layernorm, AdaRMSNorm): + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + else: + out_emb = layer.post_attention_layernorm(out_emb) + gate = torch.ones(out_emb.shape[:-1], dtype=out_emb.dtype, device=out_emb.device) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -413,8 +458,8 @@ class PaliGemmaWithExpertModel( def _set_requires_grad(self): if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - for param in self.paligemma.vision_tower.parameters(): + self.paligemma.model.vision_tower.eval() + for param in self.paligemma.model.vision_tower.parameters(): param.requires_grad = False if self.train_expert_only: self.paligemma.eval() @@ -424,7 +469,7 @@ class PaliGemmaWithExpertModel( def train(self, mode: bool = True): super().train(mode) if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() + self.paligemma.model.vision_tower.eval() if self.train_expert_only: self.paligemma.eval() @@ -432,7 +477,7 @@ class PaliGemmaWithExpertModel( return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -446,7 +491,7 @@ class PaliGemmaWithExpertModel( if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -470,7 +515,7 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.language_model, self.gemma_expert.model] + models = [self.paligemma.model.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models @@ -510,7 +555,11 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + norm = models[i].norm + if isinstance(norm, AdaRMSNorm): + out_emb, _ = norm(hidden_states, cond=adarms_cond[i]) + else: + out_emb = norm(hidden_states) outputs_embeds.append(out_emb) return outputs_embeds @@ -576,29 +625,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") @@ -760,7 +799,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) @@ -834,7 +873,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, @@ -1012,7 +1051,7 @@ class PI0Policy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf..2b0965e39 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -58,6 +58,12 @@ class ActionSelectKwargs(TypedDict, total=False): execution_horizon: int | None +def _gated_residual(residual: torch.Tensor, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """Gated residual connection: residual + gate * hidden_states. + """ + return residual + gate.unsqueeze(-1) * hidden_states + + def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -215,18 +221,53 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images +class AdaRMSNorm(nn.Module): + """RMSNorm wrapper that supports optional AdaRMS conditioning. + + When called with `cond=None`, behaves like standard RMSNorm and returns a gate of ones. + When called with a conditioning tensor, applies AdaRMS: uses a linear projection to produce + a scale and gate from the conditioning input. + """ + + def __init__(self, base_norm: nn.Module, cond_dim: int | None = None): + super().__init__() + self.base_norm = base_norm + if cond_dim is not None: + hidden_size = base_norm.weight.shape[0] + self.ada_proj = nn.Linear(cond_dim, 2 * hidden_size, bias=False) + nn.init.zeros_(self.ada_proj.weight) + else: + self.ada_proj = None + + def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None): + normed = self.base_norm(x) + if cond is None or self.ada_proj is None: + gate = torch.ones(x.shape[:-1], dtype=x.dtype, device=x.device) + return normed, gate + scale_gate = self.ada_proj(cond) + scale, gate = scale_gate.chunk(2, dim=-1) + normed = normed * (1 + scale) + return normed, gate + + # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): - models = [paligemma.language_model, gemma_expert.model] + models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] - hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + if isinstance(layer.input_layernorm, AdaRMSNorm): + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + else: + hidden_states = layer.input_layernorm(hidden_states) # noqa: PLW2901 + gate = torch.ones( + hidden_states.shape[:-1], dtype=hidden_states.dtype, device=hidden_states.device + ) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) @@ -252,10 +293,10 @@ def compute_layer_complete( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] - scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, @@ -263,7 +304,7 @@ def compute_layer_complete( scaling, ) # Get head_dim from the current layer, not from the model - head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] @@ -275,15 +316,19 @@ def compute_layer_complete( att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual - out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + out_emb = _gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 after_first_residual = out_emb.clone() - out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + if isinstance(layer.post_attention_layernorm, AdaRMSNorm): + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + else: + out_emb = layer.post_attention_layernorm(out_emb) + gate = torch.ones(out_emb.shape[:-1], dtype=out_emb.dtype, device=out_emb.device) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual - out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + out_emb = _gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds @@ -411,8 +456,8 @@ class PaliGemmaWithExpertModel( def _set_requires_grad(self): if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - for param in self.paligemma.vision_tower.parameters(): + self.paligemma.model.vision_tower.eval() + for param in self.paligemma.model.vision_tower.parameters(): param.requires_grad = False if self.train_expert_only: self.paligemma.eval() @@ -422,7 +467,7 @@ class PaliGemmaWithExpertModel( def train(self, mode: bool = True): super().train(mode) if self.freeze_vision_encoder: - self.paligemma.vision_tower.eval() + self.paligemma.model.vision_tower.eval() if self.train_expert_only: self.paligemma.eval() @@ -430,7 +475,7 @@ class PaliGemmaWithExpertModel( return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -444,7 +489,7 @@ class PaliGemmaWithExpertModel( if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -468,7 +513,7 @@ class PaliGemmaWithExpertModel( prefix_output = None prefix_past_key_values = None else: - models = [self.paligemma.language_model, self.gemma_expert.model] + models = [self.paligemma.model.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models @@ -508,7 +553,11 @@ class PaliGemmaWithExpertModel( def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): - out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + norm = models[i].norm + if isinstance(norm, AdaRMSNorm): + out_emb, _ = norm(hidden_states, cond=adarms_cond[i]) + else: + out_emb = norm(hidden_states) outputs_embeds.append(out_emb) return outputs_embeds @@ -573,29 +622,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI05Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI05Pytorch model") @@ -737,7 +776,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) @@ -808,7 +847,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) - self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, @@ -984,7 +1023,7 @@ class PI05Policy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index b4bc7ba22..47e1df8db 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -245,7 +245,7 @@ class PI0FastPaliGemma(nn.Module): return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) + return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, @@ -259,7 +259,7 @@ class PI0FastPaliGemma(nn.Module): if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: - prefix_output = self.paligemma.language_model.forward( + prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, @@ -306,24 +306,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode) self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - - try: - from transformers.models.siglip import check - - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None - def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True # Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable( + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable( + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) logging.info("Enabled gradient checkpointing for PI0FastPytorch model") @@ -332,8 +322,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False # Call the proper gradient_checkpointing_disable() method - self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable() - self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable() logging.info("Disabled gradient checkpointing for PI0FastPytorch model") def _apply_checkpoint(self, func, *args, **kwargs): @@ -523,7 +513,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Convert embeddings to bfloat16 if needed if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -616,7 +606,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` ) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -714,7 +704,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Ensure correct precision (bfloat16/float32) if ( - self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): prefix_embs = prefix_embs.to(dtype=torch.bfloat16) @@ -912,7 +902,7 @@ class PI0FastPolicy(PreTrainedPolicy): force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), - use_auth_token=kwargs.get("use_auth_token"), + token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), )