mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
More things
This commit is contained in:
@@ -69,7 +69,7 @@ from tqdm import trange
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation, preprocess_observation1
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -125,6 +125,10 @@ def rollout(
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
# added by jade
|
||||
# for k in list(policy.config.input_features.keys()):
|
||||
# if k.startswith("observation.image"):
|
||||
# policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k)
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -149,6 +153,7 @@ def rollout(
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
# observation = preprocess_observation1(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
@@ -159,6 +164,26 @@ def rollout(
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
# breakpoint()
|
||||
# observation = {
|
||||
# k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v
|
||||
# for k, v in observation.items()
|
||||
# # }
|
||||
# if "observation.image" in observation:
|
||||
# observation["image"] = observation.pop("observation.image").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
# if "observation.image2" in observation:
|
||||
# observation["wrist_image"] = observation.pop("observation.image2").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
# if "observation.state" in observation:
|
||||
# observation["state"] = observation.pop("observation.state").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
# Convert to CPU / numpy.
|
||||
@@ -489,12 +514,11 @@ def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotData
|
||||
print("Normalization layers recreated with dataset stats.")
|
||||
|
||||
|
||||
def load_smolvla(cfg, dataset_repo: str):
|
||||
def load_smolvla(cfg, dataset_repo: str, policy):
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
dataset = LeRobotDataset(dataset_repo, root='/raid/jade/.cache/huggingface/datasets/')
|
||||
policy = make_policy(cfg=cfg, ds_meta=dataset.meta)
|
||||
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
||||
return policy, dataset
|
||||
return policy.to("cuda"), dataset
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -505,7 +529,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
#login to hf
|
||||
from huggingface_hub import login
|
||||
login()
|
||||
# login()
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_seed(cfg.seed)
|
||||
@@ -520,9 +544,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
# breakpoint()
|
||||
load_smolvla(cfg.policy, "physical-intelligence/libero")
|
||||
# breakpoint()
|
||||
breakpoint()
|
||||
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
||||
# rename "image" -> "observation.image"
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
if cfg.env.multitask_eval:
|
||||
|
||||
Reference in New Issue
Block a user