mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
fix(profiling): handle datasets without metadata in forward artifacts
This commit is contained in:
@@ -270,6 +270,30 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
|
||||
|
||||
def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
|
||||
class _ImagePolicy(torch.nn.Module):
|
||||
def forward(self, batch):
|
||||
image = batch["observation.images.front"]
|
||||
assert image.dtype == torch.float32
|
||||
assert torch.all((0.0 <= image) & (image <= 1.0))
|
||||
return image.sum(), {"image": image}
|
||||
|
||||
dataset = [{"observation.images.front": torch.tensor([[[0, 255]]], dtype=torch.uint8)}]
|
||||
|
||||
mp.write_deterministic_forward_artifacts(
|
||||
policy=_ImagePolicy(),
|
||||
dataset=dataset,
|
||||
batch_size=1,
|
||||
preprocessor=lambda b: b,
|
||||
output_dir=tmp_path,
|
||||
device_type="cpu",
|
||||
)
|
||||
|
||||
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
assert payload["outputs"]["output_dict"]["image"]["dtype"] == "torch.float32"
|
||||
|
||||
|
||||
def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
|
||||
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
|
||||
profiler.start()
|
||||
|
||||
Reference in New Issue
Block a user