Fix rtc_config attribute access in SmolVLA

Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eugene Mironov
2025-11-03 17:53:37 +07:00
parent 0acdde4ae2
commit 08ff689a1e

View File

@@ -257,8 +257,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
"""
self.rtc_processor = None
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
self.rtc_processor = RTCProcessor(rtc_config)
# In case of calling init_rtc_processor after the model is created
# We need to set the rtc_processor to the model
@@ -343,7 +344,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
rtc_config = getattr(self.config, "rtc_config", None)
return rtc_config is not None and rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
@@ -806,10 +808,11 @@ class VLAFlowMatching(nn.Module):
timestep=current_timestep,
)
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
execution_horizon = kwargs.get("execution_horizon", rtc_config.execution_horizon)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -827,11 +830,8 @@ class VLAFlowMatching(nn.Module):
time += dt
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if (
self.config.rtc_config is not None
and self.config.rtc_config.enabled
and correction is not None
):
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled and correction is not None:
self.rtc_processor.track_debug(time=time, x_t=x_t)
# Visualize x_t using plot_waypoints - accumulate all denoise steps
@@ -902,7 +902,8 @@ class VLAFlowMatching(nn.Module):
xt_name = "smolvla_x_t_denoise_steps.png"
v_name = "smolvla_v_denoise_steps.png"
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
v_name = "smolvla_v_with_rtc_denoise_steps.png"
@@ -931,11 +932,8 @@ class VLAFlowMatching(nn.Module):
# Plot ground truth on provided axes if available
if use_provided_axes:
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
if (
prev_chunk_left_over is not None
and self.config.rtc_config is not None
and self.config.rtc_config.enabled
):
rtc_config = getattr(self.config, "rtc_config", None)
if prev_chunk_left_over is not None and rtc_config is not None and rtc_config.enabled:
plot_waypoints(
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)