diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index dd49a45f7..2153f806a 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -257,8 +257,9 @@ class SmolVLAPolicy(PreTrainedPolicy): """ self.rtc_processor = None - if self.config.rtc_config is not None and self.config.rtc_config.enabled: - self.rtc_processor = RTCProcessor(self.config.rtc_config) + rtc_config = getattr(self.config, "rtc_config", None) + if rtc_config is not None and rtc_config.enabled: + self.rtc_processor = RTCProcessor(rtc_config) # In case of calling init_rtc_processor after the model is created # We need to set the rtc_processor to the model @@ -343,7 +344,8 @@ class SmolVLAPolicy(PreTrainedPolicy): return len(self._queues[ACTION]) == 0 def _rtc_enabled(self) -> bool: - return self.config.rtc_config is not None and self.config.rtc_config.enabled + rtc_config = getattr(self.config, "rtc_config", None) + return rtc_config is not None and rtc_config.enabled def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: """Do a full training forward pass to compute the loss""" @@ -806,10 +808,11 @@ class VLAFlowMatching(nn.Module): timestep=current_timestep, ) - if self.config.rtc_config is not None and self.config.rtc_config.enabled: + rtc_config = getattr(self.config, "rtc_config", None) + if rtc_config is not None and rtc_config.enabled: inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon) + execution_horizon = kwargs.get("execution_horizon", rtc_config.execution_horizon) v_t = self.rtc_processor.denoise_step( x_t=x_t, @@ -827,11 +830,8 @@ class VLAFlowMatching(nn.Module): time += dt # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - if ( - self.config.rtc_config is not None - and self.config.rtc_config.enabled - and correction is not None - ): + rtc_config = getattr(self.config, "rtc_config", None) + if rtc_config is not None and rtc_config.enabled and correction is not None: self.rtc_processor.track_debug(time=time, x_t=x_t) # Visualize x_t using plot_waypoints - accumulate all denoise steps @@ -902,7 +902,8 @@ class VLAFlowMatching(nn.Module): xt_name = "smolvla_x_t_denoise_steps.png" v_name = "smolvla_v_denoise_steps.png" - if self.config.rtc_config is not None and self.config.rtc_config.enabled: + rtc_config = getattr(self.config, "rtc_config", None) + if rtc_config is not None and rtc_config.enabled: xt_name = "smolvla_x_t_with_rtc_denoise_steps.png" v_name = "smolvla_v_with_rtc_denoise_steps.png" @@ -931,11 +932,8 @@ class VLAFlowMatching(nn.Module): # Plot ground truth on provided axes if available if use_provided_axes: prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - if ( - prev_chunk_left_over is not None - and self.config.rtc_config is not None - and self.config.rtc_config.enabled - ): + rtc_config = getattr(self.config, "rtc_config", None) + if prev_chunk_left_over is not None and rtc_config is not None and rtc_config.enabled: plot_waypoints( viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" )