mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
more changes
This commit is contained in:
@@ -185,27 +185,28 @@ def rollout(
|
||||
observation[f"observation.images.image2"] = observation[f"observation.images.image2"] * 255
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
inputs = processor([observation[f"observation.images.image"], observation[f"observation.images.image2"]], observation["task"])
|
||||
# inputs = processor([observation[f"observation.images.image"], observation[f"observation.images.image2"]], observation["task"])
|
||||
observation = preprocessor(observation)
|
||||
inputs_1 = policy._build_model_inputs(observation)
|
||||
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
a = inputs[k].to("cuda")
|
||||
b = inputs_1[k].to("cuda")
|
||||
observation["domain_id"] = torch.tensor([int(3)], dtype=torch.long).to("cuda")
|
||||
# inputs_1 = policy._build_model_inputs(observation)
|
||||
# for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
# a = inputs[k].to("cuda")
|
||||
# b = inputs_1[k].to("cuda")
|
||||
|
||||
print(f"\n🔎 Key: {k}")
|
||||
# print(f"\n🔎 Key: {k}")
|
||||
|
||||
# Check shape
|
||||
print(" shape:", a.shape, b.shape)
|
||||
# # Check shape
|
||||
# print(" shape:", a.shape, b.shape)
|
||||
|
||||
# Check if close
|
||||
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||
print(" ✔️ tensors are equal (allclose)")
|
||||
else:
|
||||
diff = torch.abs(a - b)
|
||||
print(" ❌ tensors differ")
|
||||
print(" max diff:", diff.max().item())
|
||||
print(" mean diff:", diff.mean().item())
|
||||
breakpoint()
|
||||
# # Check if close
|
||||
# if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||
# print(" ✔️ tensors are equal (allclose)")
|
||||
# else:
|
||||
# diff = torch.abs(a - b)
|
||||
# print(" ❌ tensors differ")
|
||||
# print(" max diff:", diff.max().item())
|
||||
# print(" mean diff:", diff.mean().item())
|
||||
# breakpoint()
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation).to("cpu").numpy()
|
||||
# if len(action_queue) == 0:
|
||||
@@ -230,7 +231,6 @@ def rollout(
|
||||
# target_axis_1 = Rotate6D_to_AxisAngle(action_1[:, 3:9])
|
||||
# target_act_1 = action_1[:, 9:10]
|
||||
# action_numpy_1 = np.concatenate([target_eef_1, target_axis_1, target_act_1], axis=-1)
|
||||
breakpoint()
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
# action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
|
||||
Reference in New Issue
Block a user