fix renaming issues with cams

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-20 06:55:05 -04:00
parent 5d25f5bd40
commit cc46497f4c
9 changed files with 59 additions and 49 deletions

View File

@@ -146,7 +146,8 @@ def rollout(
check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
# observation = preprocess_observation(observation)
observation = preprocess_observation(observation, cfg=policy.config)
if return_observations:
all_observations.append(deepcopy(observation))
@@ -159,7 +160,6 @@ def rollout(
observation = add_envs_task(env, observation)
with torch.inference_mode():
action = policy.select_action(observation)
observation["observation.images.image"]
# Convert to CPU / numpy.
action = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -198,7 +198,7 @@ def rollout(
# Track the final observation.
if return_observations:
observation = preprocess_observation(observation)
observation = preprocess_observation(observation, cfg=policy.config)
all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.

View File

@@ -269,6 +269,7 @@ def train(cfg: TrainPipelineConfig):
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
breakpoint()
else:
print("START EVAL")
eval_info = eval_policy(
@@ -279,6 +280,7 @@ def train(cfg: TrainPipelineConfig):
max_episodes_rendered=4,
start_seed=cfg.seed,
)
breakpoint()
aggregated = eval_info["aggregated"]
print("END EVAL")