diff --git a/src/lerobot/utils/model_profiling.py b/src/lerobot/utils/model_profiling.py
index f7887ff12..7fe2a9cd1 100644
--- a/src/lerobot/utils/model_profiling.py
+++ b/src/lerobot/utils/model_profiling.py
@@ -333,7 +333,20 @@ def write_deterministic_forward_artifacts(
     # the dataloader (PR #3406). The dataset ships camera frames as uint8 for
     # faster transport, but policies like SmolVLA/xVLA run bilinear
     # interpolation on images which doesn't support Byte tensors.
-    for cam_key in dataset.meta.camera_keys:
+    # Prefer the camera keys declared by the dataset metadata. Datasets that
+    # expose no ``meta.camera_keys`` (e.g. a plain list of samples) fall back
+    # to the batch keys under the ``observation.images.`` prefix.
+    camera_keys = tuple(getattr(getattr(dataset, "meta", None), "camera_keys", ()) or ())
+    if not camera_keys:
+        camera_keys = tuple(key for key in reference_batch if key.startswith("observation.images."))
+    # Only tensors carry a dtype: drop keys whose batch entry is missing or is
+    # not a tensor so the uint8 check below cannot raise AttributeError.
+    camera_keys = tuple(
+        key
+        for key in camera_keys
+        if key in reference_batch and isinstance(reference_batch[key], torch.Tensor)
+    )
+    for cam_key in camera_keys:
         if cam_key in reference_batch and reference_batch[cam_key].dtype == torch.uint8:
             reference_batch[cam_key] = reference_batch[cam_key].to(dtype=torch.float32) / 255.0
     reference_batch = preprocessor(reference_batch)
diff --git a/tests/test_model_profiling.py b/tests/test_model_profiling.py
index 9952bc64f..bb2a06dfd 100644
--- a/tests/test_model_profiling.py
+++ b/tests/test_model_profiling.py
@@ -270,6 +270,35 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
     assert payload["outputs"]["loss"]["numel"] == 1
 
 
+def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
+    """uint8 frames are rescaled to float32 [0, 1] even when the dataset has no ``meta``."""
+
+    class _ImagePolicy(torch.nn.Module):
+        def forward(self, batch):
+            image = batch["observation.images.front"]
+            assert image.dtype == torch.float32
+            assert torch.all((0.0 <= image) & (image <= 1.0))
+            # 255 must map to exactly 1.0; a bare range check would also accept a wrong divisor.
+            assert image.max().item() == 1.0
+            return image.sum(), {"image": image}
+
+    # A plain list has no ``meta`` attribute, which forces the key-prefix fallback.
+    dataset = [{"observation.images.front": torch.tensor([[[0, 255]]], dtype=torch.uint8)}]
+
+    mp.write_deterministic_forward_artifacts(
+        policy=_ImagePolicy(),
+        dataset=dataset,
+        batch_size=1,
+        preprocessor=lambda b: b,
+        output_dir=tmp_path,
+        device_type="cpu",
+    )
+
+    payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
+    assert payload["outputs"]["loss"]["numel"] == 1
+    assert payload["outputs"]["output_dict"]["image"]["dtype"] == "torch.float32"
+
+
 def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
     profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
     profiler.start()