mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
Fix rtc_config attribute access in SmolVLA
Use getattr() to safely check for rtc_config attribute existence instead of direct attribute access. This fixes AttributeError when loading policies without rtc_config in their config. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -257,8 +257,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self.rtc_processor = None
|
||||
|
||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
if rtc_config is not None and rtc_config.enabled:
|
||||
self.rtc_processor = RTCProcessor(rtc_config)
|
||||
|
||||
# In case of calling init_rtc_processor after the model is created
|
||||
# We need to set the rtc_processor to the model
|
||||
@@ -343,7 +344,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
return len(self._queues[ACTION]) == 0
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
return rtc_config is not None and rtc_config.enabled
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
@@ -806,10 +808,11 @@ class VLAFlowMatching(nn.Module):
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
if rtc_config is not None and rtc_config.enabled:
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
|
||||
execution_horizon = kwargs.get("execution_horizon", rtc_config.execution_horizon)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -827,11 +830,8 @@ class VLAFlowMatching(nn.Module):
|
||||
time += dt
|
||||
|
||||
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if (
|
||||
self.config.rtc_config is not None
|
||||
and self.config.rtc_config.enabled
|
||||
and correction is not None
|
||||
):
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
if rtc_config is not None and rtc_config.enabled and correction is not None:
|
||||
self.rtc_processor.track_debug(time=time, x_t=x_t)
|
||||
|
||||
# Visualize x_t using plot_waypoints - accumulate all denoise steps
|
||||
@@ -902,7 +902,8 @@ class VLAFlowMatching(nn.Module):
|
||||
xt_name = "smolvla_x_t_denoise_steps.png"
|
||||
v_name = "smolvla_v_denoise_steps.png"
|
||||
|
||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
if rtc_config is not None and rtc_config.enabled:
|
||||
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
|
||||
v_name = "smolvla_v_with_rtc_denoise_steps.png"
|
||||
|
||||
@@ -931,11 +932,8 @@ class VLAFlowMatching(nn.Module):
|
||||
# Plot ground truth on provided axes if available
|
||||
if use_provided_axes:
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
if (
|
||||
prev_chunk_left_over is not None
|
||||
and self.config.rtc_config is not None
|
||||
and self.config.rtc_config.enabled
|
||||
):
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
if prev_chunk_left_over is not None and rtc_config is not None and rtc_config.enabled:
|
||||
plot_waypoints(
|
||||
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user