fix(profiling): handle datasets without metadata in forward artifacts

This commit is contained in:
Pepijn
2026-04-21 18:04:35 +02:00
parent ce9bfa754d
commit fe78f8fee9
2 changed files with 32 additions and 1 deletions

View File

@@ -333,7 +333,14 @@ def write_deterministic_forward_artifacts(
# the dataloader (PR #3406). The dataset ships camera frames as uint8 for
# faster transport, but policies like SmolVLA/xVLA run bilinear
# interpolation on images which doesn't support Byte tensors.
for cam_key in dataset.meta.camera_keys:
camera_keys = tuple(getattr(getattr(dataset, "meta", None), "camera_keys", ()) or ())
if not camera_keys:
camera_keys = tuple(
key
for key, value in reference_batch.items()
if key.startswith("observation.images.") and isinstance(value, torch.Tensor)
)
for cam_key in camera_keys:
if cam_key in reference_batch and reference_batch[cam_key].dtype == torch.uint8:
reference_batch[cam_key] = reference_batch[cam_key].to(dtype=torch.float32) / 255.0
reference_batch = preprocessor(reference_batch)

View File

@@ -270,6 +270,30 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
assert payload["outputs"]["loss"]["numel"] == 1
def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
    """Forward artifacts are written for a dataset that has no ``meta`` attribute.

    The dataset here is a plain list holding a single sample whose only entry
    is a uint8 tensor under an ``observation.images.*`` key. The policy itself
    asserts that the image it receives is float32 with every value in [0, 1],
    and the test then checks the JSON artifact written to ``tmp_path``.
    """

    class _ImagePolicy(torch.nn.Module):
        """Minimal policy that validates the image it is given."""

        def forward(self, batch):
            frame = batch["observation.images.front"]
            # The policy must never see the raw uint8 frame.
            assert frame.dtype == torch.float32
            within_unit_range = (frame >= 0.0) & (frame <= 1.0)
            assert torch.all(within_unit_range)
            return frame.sum(), {"image": frame}

    # One sample containing both extremes of the uint8 range.
    raw_frame = torch.tensor([[[0, 255]]], dtype=torch.uint8)
    samples = [{"observation.images.front": raw_frame}]

    mp.write_deterministic_forward_artifacts(
        policy=_ImagePolicy(),
        dataset=samples,
        batch_size=1,
        preprocessor=lambda b: b,
        output_dir=tmp_path,
        device_type="cpu",
    )

    artifact_file = tmp_path / "deterministic_forward.json"
    recorded_outputs = json.loads(artifact_file.read_text())["outputs"]
    assert recorded_outputs["loss"]["numel"] == 1
    assert recorded_outputs["output_dict"]["image"]["dtype"] == "torch.float32"
def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
profiler.start()