General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info

This commit is contained in:
Michel Aractingi
2025-04-09 17:04:43 +02:00
committed by Michel Aractingi
parent 02e1ed0bfb
commit 9fd4c21d4d
7 changed files with 75 additions and 65 deletions

View File

@@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(
val, device=device, non_blocking=non_blocking
)
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
@@ -505,7 +503,6 @@ class ReplayBuffer:
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
action_delta: Optional[float] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
@@ -520,7 +517,6 @@ class ReplayBuffer:
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
action_delta (Optional[float]): Factor to divide actions by.
image_augmentation_function (Optional[Callable]): Function for image augmentation.
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -565,9 +561,6 @@ class ReplayBuffer:
else:
first_action = first_action[:, action_mask]
if action_delta is not None:
first_action = first_action / action_delta
# Get complementary info if available
first_complementary_info = None
if (
@@ -598,9 +591,6 @@ class ReplayBuffer:
else:
action = action[:, action_mask]
if action_delta is not None:
action = action / action_delta
replay_buffer.add(
state=data["state"],
action=action,