Refactor RTC enabled checks to use _rtc_enabled helper

Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eugene Mironov
2025-11-03 18:08:39 +07:00
parent 455d347b49
commit 70548e55f0

View File

@@ -551,6 +551,9 @@ class VLAFlowMatching(nn.Module):
self.viz_fig = None
self.viz_axs = None
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -806,7 +809,7 @@ class VLAFlowMatching(nn.Module):
timestep=current_timestep,
)
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
@@ -827,11 +830,7 @@ class VLAFlowMatching(nn.Module):
time += dt
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if (
self.config.rtc_config is not None
and self.config.rtc_config.enabled
and correction is not None
):
if self._rtc_enabled() and correction is not None:
self.rtc_processor.track_debug(time=time, x_t=x_t)
# Visualize x_t using plot_waypoints - accumulate all denoise steps
@@ -931,11 +930,7 @@ class VLAFlowMatching(nn.Module):
# Plot ground truth on provided axes if available
if use_provided_axes:
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
if (
prev_chunk_left_over is not None
and self.config.rtc_config is not None
and self.config.rtc_config.enabled
):
if prev_chunk_left_over is not None and self._rtc_enabled():
plot_waypoints(
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)