More things

This commit is contained in:
Jade Choghari
2025-09-10 23:24:18 +02:00
parent 5c628f1700
commit aa40c8c813
12 changed files with 3670 additions and 48 deletions

View File

@@ -69,7 +69,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
@@ -125,6 +125,10 @@ def rollout(
# Reset the policy and environments.
policy.reset()
# added by jade
# for k in list(policy.config.input_features.keys()):
# if k.startswith("observation.image"):
# policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k)
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@@ -149,6 +153,7 @@ def rollout(
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
# observation = preprocess_observation1(observation)
if return_observations:
all_observations.append(deepcopy(observation))
@@ -159,6 +164,26 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# breakpoint()
# observation = {
# k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v
# for k, v in observation.items()
# # }
# if "observation.image" in observation:
# observation["image"] = observation.pop("observation.image").to(
# device, non_blocking=device.type == "cuda"
# )
# if "observation.image2" in observation:
# observation["wrist_image"] = observation.pop("observation.image2").to(
# device, non_blocking=device.type == "cuda"
# )
# if "observation.state" in observation:
# observation["state"] = observation.pop("observation.state").to(
# device, non_blocking=device.type == "cuda"
# )
with torch.inference_mode():
action = policy.select_action(observation)
# Convert to CPU / numpy.
@@ -489,12 +514,11 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
print("Normalization layers recreated with dataset stats.")
def load_smolvla(cfg, dataset_repo: str):
def load_smolvla(cfg, dataset_repo: str, policy):
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
policy = make_policy(cfg=cfg, ds_meta=dataset.meta)
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
return policy, dataset
return policy.to("cuda"), dataset
@parser.wrap()
@@ -505,7 +529,7 @@ def eval_main(cfg: EvalPipelineConfig):
device = get_safe_torch_device(cfg.policy.device, log=True)
#login to hf
from huggingface_hub import login
login()
# login()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
@@ -520,9 +544,10 @@ def eval_main(cfg: EvalPipelineConfig):
cfg=cfg.policy,
env_cfg=cfg.env,
)
# breakpoint()
load_smolvla(cfg.policy, "physical-intelligence/libero")
# breakpoint()
breakpoint()
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
# rename "image" -> "observation.image"
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
if cfg.env.multitask_eval: