diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 359b4fdb1..6d3976b79 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -467,8 +467,8 @@ class VQBeTHead(nn.Module): self.vqvae_model.optimized_steps += 1 # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part. if self.vqvae_model.optimized_steps >= n_vqvae_training_steps: - self.vqvae_model.discretized = torch.tensor(True) - self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True) + self.vqvae_model.discretized.fill_(True) + self.vqvae_model.vq_layer.freeze_codebook.fill_(True) print("Finished discretizing action data!") self.vqvae_model.eval() for param in self.vqvae_model.vq_layer.parameters(): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 1aae3fcc8..77a74d60e 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -42,6 +42,8 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -460,3 +462,45 @@ def test_act_temporal_ensembler(): assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max")) # Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error. torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4) + + +def test_vqbet_discretize_keeps_buffers_on_device(): + """Regression test: VQBeTHead.discretize() must not move registered buffers off the model device. + + Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the + registered buffer with a new CPU tensor, causing DDP to crash with: + RuntimeError: No backend type associated with device type cpu + The fix uses `.fill_(True)` to update in-place, preserving device placement. + """ + config = VQBeTConfig() + config.input_features = { + OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + # Tiny sizes for fast CPU/GPU execution. + config.n_vqvae_training_steps = 3 + config.vqvae_n_embed = 8 + config.vqvae_embedding_dim = 32 + config.vqvae_enc_hidden_dim = 32 + config.action_chunk_size = 2 + config.crop_shape = (84, 84) + + head = VQBeTHead(config).to(DEVICE) + vqvae = head.vqvae_model + + dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE) + n_steps = config.n_vqvae_training_steps + for _ in range(n_steps): + head.discretize(n_steps, dummy_actions) + + assert vqvae.discretized.device.type == torch.device(DEVICE).type, ( + "vqvae_model.discretized was moved off the model device after discretize(). " + "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device." + ) + assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, ( + "vq_layer.freeze_codebook was moved off the model device after discretize(). " + "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device." + )