From 70548e55f054437f2fbb7ade152f6caebe9954db Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 3 Nov 2025 18:08:39 +0700 Subject: [PATCH] Refactor RTC enabled checks to use _rtc_enabled helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add _rtc_enabled() helper method in VLAFlowMatching class to simplify and clean up RTC enabled checks throughout the code. This reduces code duplication and improves readability. Changes: - Add _rtc_enabled() method in VLAFlowMatching - Replace verbose rtc_config checks with _rtc_enabled() calls - Maintain exact same functionality with cleaner code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare Co-Authored-By: Claude --- .../policies/smolvla/modeling_smolvla.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index dd49a45f7..a9cdbf7a0 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -551,6 +551,9 @@ class VLAFlowMatching(nn.Module): self.viz_fig = None self.viz_axs = None + def _rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def set_requires_grad(self): for params in self.state_proj.parameters(): params.requires_grad = self.config.train_state_proj @@ -806,7 +809,7 @@ class VLAFlowMatching(nn.Module): timestep=current_timestep, ) - if self.config.rtc_config is not None and self.config.rtc_config.enabled: + if self._rtc_enabled(): inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon) @@ -827,11 +830,7 @@ class VLAFlowMatching(nn.Module): time += dt # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - if ( - self.config.rtc_config is not None - and self.config.rtc_config.enabled - and correction is not None - ): + if self._rtc_enabled() and correction is not None: self.rtc_processor.track_debug(time=time, x_t=x_t) # Visualize x_t using plot_waypoints - accumulate all denoise steps @@ -931,11 +930,7 @@ class VLAFlowMatching(nn.Module): # Plot ground truth on provided axes if available if use_provided_axes: prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - if ( - prev_chunk_left_over is not None - and self.config.rtc_config is not None - and self.config.rtc_config.enabled - ): + if prev_chunk_left_over is not None and self._rtc_enabled(): plot_waypoints( viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" )