mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
219 lines
6.9 KiB
Python
219 lines
6.9 KiB
Python
|
|
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import tempfile
|
||
|
|
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
from lerobot.configs.parser import (
|
||
|
|
_config_path_args,
|
||
|
|
_config_yaml_overrides,
|
||
|
|
_flatten_to_cli_args,
|
||
|
|
extract_path_fields_from_config,
|
||
|
|
get_path_arg,
|
||
|
|
get_yaml_overrides,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_path_fields_from_yaml():
    """Test that policy.path is extracted from a YAML config and removed.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            # Path should be extracted and stored
            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            # Cleaned config should not have the path field
            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)
            assert "path" not in cleaned["policy"]
            assert cleaned["policy"]["type"] == "smolvla"
            assert cleaned["policy"]["push_to_hub"] is False

            # Original dataset should be untouched
            assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
    finally:
        _config_path_args.clear()
        # The extractor may also record non-path policy fields as overrides
        # (the override tests in this file rely on that); reset them as well.
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_path_fields_from_json():
    """Test that policy.path is extracted from a JSON config.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "policy": {"type": "act", "path": "some/local/path"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", dir=tmp_dir, delete=False
            ) as f:
                json.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "some/local/path"

            with open(cleaned_path) as f:
                cleaned = json.load(f)
            assert "path" not in cleaned["policy"]
    finally:
        _config_path_args.clear()
        # The extractor may also record non-path policy fields as overrides
        # (the override tests in this file rely on that); reset them as well.
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_no_path_returns_original():
    """Test that configs without path fields are returned unchanged.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            result = extract_path_fields_from_config(config_path, ["policy"])

            # With nothing to extract, the very same path must come back.
            assert result == config_path
            assert "policy" not in _config_path_args
    finally:
        _config_path_args.clear()
        # Reset the override store too, in case the extractor touched it.
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_removes_empty_field():
    """Test that the field dict is removed entirely if path was the only key.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)
            # "path" was the only key, so the whole "policy" section must be gone.
            assert "policy" not in cleaned
    finally:
        _config_path_args.clear()
        # Reset the override store too, in case the extractor touched it.
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_path_arg_fallback():
    """Test that get_path_arg falls back to _config_path_args when CLI has no path.

    The shared ``_config_path_args`` dict is reset in a ``finally`` block so
    the seeded entry never survives a failing assertion.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "lerobot/smolvla_base"
    try:
        # No CLI args with --policy.path
        result = get_path_arg("policy", args=[])
        assert result == "lerobot/smolvla_base"
    finally:
        _config_path_args.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_path_arg_cli_takes_precedence():
    """Test that CLI --policy.path takes precedence over YAML config path.

    The shared ``_config_path_args`` dict is reset in a ``finally`` block so
    the seeded entry never survives a failing assertion.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "yaml/path"
    try:
        # Both sources supply a path; the command-line value must win.
        result = get_path_arg("policy", args=["--policy.path=cli/path"])
        assert result == "cli/path"
    finally:
        _config_path_args.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_yaml_overrides_captured():
    """Test that non-path policy fields are captured as CLI-style overrides.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and both module-level parser dicts are reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

            overrides = get_yaml_overrides("policy")
            # Only the key is checked for lr: an exact "--lr=0.0001" entry would
            # already satisfy this substring test, so no separate clause is needed.
            assert any("lr=" in o for o in overrides)
            assert any("batch_size=32" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_yaml_overrides_excludes_type_and_path():
    """Test that type and path fields are not included in YAML overrides.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and both module-level parser dicts are reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

            overrides = get_yaml_overrides("policy")
            assert not any("path=" in o for o in overrides)
            assert not any("type=" in o for o in overrides)
            assert any("lr=" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_yaml_overrides_empty_when_path_only():
    """Test that get_yaml_overrides returns [] when policy had only a path field.

    The input config is written inside a temporary directory so it is deleted
    when the test ends, and both module-level parser dicts are reset in a
    ``finally`` block so a failing assertion cannot leak into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed so the
            # parser can reopen it by name; the enclosing directory removes it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

            assert get_yaml_overrides("policy") == []
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def test_flatten_bool_values():
    """Check that booleans are flattened to lowercase strings for draccus.

    Non-boolean scalars (a float and a string) are included to confirm they
    are rendered as ``--key=value`` alongside the booleans.
    """
    flattened = _flatten_to_cli_args(
        {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"}
    )

    # Each expected argument must appear verbatim in the flattened output.
    for expected_arg in (
        "--push_to_hub=true",
        "--use_rabc=false",
        "--lr=0.001",
        "--name=test",
    ):
        assert expected_arg in flattened
|
||
|
|
|
||
|
|
|
||
|
|
def test_flatten_none_values_skipped():
    """Check that keys whose value is None are left out of the flattened args."""
    source = {"lr": 0.001, "path_override": None, "name": "test"}
    flattened = _flatten_to_cli_args(source)

    # Keys with real values must each show up in at least one argument.
    for present_key in ("lr=", "name="):
        assert any(present_key in arg for arg in flattened)

    # The None-valued key must not be mentioned by any argument.
    assert all("path_override" not in arg for arg in flattened)
|
||
|
|
|
||
|
|
|
||
|
|
def test_flatten_nested_with_bools():
    """Check that booleans inside nested dicts get dotted keys and lowercase values."""
    nested = {"optimizer": {"use_warmup": True, "lr": 0.01}}
    flattened = _flatten_to_cli_args(nested)

    # Nested keys are expected to be joined with "." under the parent key.
    for expected_arg in ("--optimizer.use_warmup=true", "--optimizer.lr=0.01"):
        assert expected_arg in flattened
|