[Port HIl-Serl] Refactor gym-manipulator (#1034)

This commit is contained in:
Michel Aractingi
2025-04-25 16:34:54 +02:00
committed by GitHub
parent a8da4a347e
commit bd4db8d747
13 changed files with 624 additions and 946 deletions

View File

@@ -15,26 +15,15 @@
# limitations under the License.
import functools
import io
import pickle # nosec B403: Safe usage of pickle
from contextlib import suppress
from typing import Any, Callable, Optional, Sequence, TypedDict
from typing import Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class Transition(TypedDict):
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: float
next_state: dict[str, torch.Tensor]
done: bool
truncated: bool
complementary_info: dict[str, torch.Tensor | float | int] | None = None
from lerobot.scripts.server.utils import Transition
class BatchTransition(TypedDict):
@@ -47,103 +36,6 @@ class BatchTransition(TypedDict):
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
device = torch.device(device)
non_blocking = device.type == "cuda"
# Move state tensors to device
transition["state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
# Move next_state tensors to device
transition["next_state"] = {
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
}
# Move complementary_info tensors if present
if transition.get("complementary_info") is not None:
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
return transition
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
"""
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
else:
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer.getvalue()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
# Add validation checks here
return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
@@ -514,7 +406,6 @@ class ReplayBuffer:
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
@@ -566,13 +457,6 @@ class ReplayBuffer:
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
# Apply action mask/delta if needed
if action_mask is not None:
if first_action.dim() == 1:
first_action = first_action[action_mask]
else:
first_action = first_action[:, action_mask]
# Get complementary info if available
first_complementary_info = None
if (
@@ -597,8 +481,6 @@ class ReplayBuffer:
data[k] = v.to(storage_device)
action = data["action"]
if action_mask is not None:
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
replay_buffer.add(
state=data["state"],