more changes

This commit is contained in:
Jade Choghari
2025-11-16 11:39:17 +01:00
parent 589788e760
commit 818c75713b

View File

@@ -185,27 +185,28 @@ def rollout(
observation[f"observation.images.image2"] = observation[f"observation.images.image2"] * 255
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
inputs = processor([observation[f"observation.images.image"], observation[f"observation.images.image2"]], observation["task"])
# inputs = processor([observation[f"observation.images.image"], observation[f"observation.images.image2"]], observation["task"])
observation = preprocessor(observation)
inputs_1 = policy._build_model_inputs(observation)
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
a = inputs[k].to("cuda")
b = inputs_1[k].to("cuda")
observation["domain_id"] = torch.tensor([int(3)], dtype=torch.long).to("cuda")
# inputs_1 = policy._build_model_inputs(observation)
# for k in inputs.keys() & inputs_1.keys(): # intersection of keys
# a = inputs[k].to("cuda")
# b = inputs_1[k].to("cuda")
print(f"\n🔎 Key: {k}")
# print(f"\n🔎 Key: {k}")
# Check shape
print(" shape:", a.shape, b.shape)
# # Check shape
# print(" shape:", a.shape, b.shape)
# Check if close
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
print(" ✔️ tensors are equal (allclose)")
else:
diff = torch.abs(a - b)
print(" ❌ tensors differ")
print(" max diff:", diff.max().item())
print(" mean diff:", diff.mean().item())
breakpoint()
# # Check if close
# if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
# print(" ✔️ tensors are equal (allclose)")
# else:
# diff = torch.abs(a - b)
# print(" ❌ tensors differ")
# print(" max diff:", diff.max().item())
# print(" mean diff:", diff.mean().item())
# breakpoint()
with torch.inference_mode():
action = policy.select_action(observation).to("cpu").numpy()
# if len(action_queue) == 0:
@@ -230,7 +231,6 @@ def rollout(
# target_axis_1 = Rotate6D_to_AxisAngle(action_1[:, 3:9])
# target_act_1 = action_1[:, 9:10]
# action_numpy_1 = np.concatenate([target_eef_1, target_axis_1, target_act_1], axis=-1)
breakpoint()
# Convert to CPU / numpy.
# action_numpy: np.ndarray = action.to("cpu").numpy()