mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
fix(profiling): handle datasets without metadata in forward artifacts
This commit is contained in:
@@ -333,7 +333,14 @@ def write_deterministic_forward_artifacts(
|
||||
# the dataloader (PR #3406). The dataset ships camera frames as uint8 for
|
||||
# faster transport, but policies like SmolVLA/xVLA run bilinear
|
||||
# interpolation on images which doesn't support Byte tensors.
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
camera_keys = tuple(getattr(getattr(dataset, "meta", None), "camera_keys", ()) or ())
|
||||
if not camera_keys:
|
||||
camera_keys = tuple(
|
||||
key
|
||||
for key, value in reference_batch.items()
|
||||
if key.startswith("observation.images.") and isinstance(value, torch.Tensor)
|
||||
)
|
||||
for cam_key in camera_keys:
|
||||
if cam_key in reference_batch and reference_batch[cam_key].dtype == torch.uint8:
|
||||
reference_batch[cam_key] = reference_batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
reference_batch = preprocessor(reference_batch)
|
||||
|
||||
@@ -270,6 +270,30 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
|
||||
|
||||
def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
|
||||
class _ImagePolicy(torch.nn.Module):
|
||||
def forward(self, batch):
|
||||
image = batch["observation.images.front"]
|
||||
assert image.dtype == torch.float32
|
||||
assert torch.all((0.0 <= image) & (image <= 1.0))
|
||||
return image.sum(), {"image": image}
|
||||
|
||||
dataset = [{"observation.images.front": torch.tensor([[[0, 255]]], dtype=torch.uint8)}]
|
||||
|
||||
mp.write_deterministic_forward_artifacts(
|
||||
policy=_ImagePolicy(),
|
||||
dataset=dataset,
|
||||
batch_size=1,
|
||||
preprocessor=lambda b: b,
|
||||
output_dir=tmp_path,
|
||||
device_type="cpu",
|
||||
)
|
||||
|
||||
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
assert payload["outputs"]["output_dict"]["image"]["dtype"] == "torch.float32"
|
||||
|
||||
|
||||
def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
|
||||
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
|
||||
profiler.start()
|
||||
|
||||
Reference in New Issue
Block a user