fix(profiling): handle datasets without metadata in forward artifacts

This commit is contained in:
Pepijn
2026-04-21 18:04:35 +02:00
parent ce9bfa754d
commit fe78f8fee9
2 changed files with 32 additions and 1 deletions

View File

@@ -270,6 +270,30 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
assert payload["outputs"]["loss"]["numel"] == 1
def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
class _ImagePolicy(torch.nn.Module):
def forward(self, batch):
image = batch["observation.images.front"]
assert image.dtype == torch.float32
assert torch.all((0.0 <= image) & (image <= 1.0))
return image.sum(), {"image": image}
dataset = [{"observation.images.front": torch.tensor([[[0, 255]]], dtype=torch.uint8)}]
mp.write_deterministic_forward_artifacts(
policy=_ImagePolicy(),
dataset=dataset,
batch_size=1,
preprocessor=lambda b: b,
output_dir=tmp_path,
device_type="cpu",
)
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
assert payload["outputs"]["loss"]["numel"] == 1
assert payload["outputs"]["output_dict"]["image"]["dtype"] == "torch.float32"
def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
profiler.start()