From a2f72e42f60a2e9409bb4450cd67a011e7a8aaa3 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 20 Apr 2026 23:33:24 +0200 Subject: [PATCH] fix(profiling): convert uint8 images to float32 in deterministic forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror the uint8 → float32/255 conversion the train loop applies after the dataloader (PR #3406). The reference batch in `write_deterministic_forward_artifacts` skipped this step because it calls `preprocessor(default_collate(...))` directly, which caused SmolVLA and xVLA to crash with: NotImplementedError: "upsample_bilinear2d_out_frame" not implemented for 'Byte' inside their `resize_with_pad` → `F.interpolate(..., mode="bilinear")` path. Other policies dodged it because their image-prep casts first. Made-with: Cursor --- src/lerobot/utils/model_profiling.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/lerobot/utils/model_profiling.py b/src/lerobot/utils/model_profiling.py index 885a2ecf0..f7887ff12 100644 --- a/src/lerobot/utils/model_profiling.py +++ b/src/lerobot/utils/model_profiling.py @@ -328,7 +328,15 @@ def write_deterministic_forward_artifacts( if len(dataset) == 0: raise ValueError("Cannot build a reference batch from an empty dataset.") indices = [i % len(dataset) for i in range(batch_size)] - reference_batch = preprocessor(default_collate([dataset[i] for i in indices])) + reference_batch = default_collate([dataset[i] for i in indices]) + # Mirror the uint8 → float32/255 conversion the train loop applies after + # the dataloader (PR #3406). The dataset ships camera frames as uint8 for + # faster transport, but policies like SmolVLA/xVLA run bilinear + # interpolation on images which doesn't support Byte tensors. + for cam_key in dataset.meta.camera_keys: + if cam_key in reference_batch and reference_batch[cam_key].dtype == torch.uint8: + reference_batch[cam_key] = reference_batch[cam_key].to(dtype=torch.float32) / 255.0 + reference_batch = preprocessor(reference_batch) activities = [torch.profiler.ProfilerActivity.CPU] if device_type == "cuda":