mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
Sub threading for multiprocessing
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import multiprocessing
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -135,10 +136,11 @@ class ReplayBuffer:
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
state_keys: Optional[list[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
use_shared_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -150,16 +152,17 @@ class ReplayBuffer:
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style when sampling from the buffer.
|
||||
storage_device (str): The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
|
||||
Using "cpu" can help save GPU memory.
|
||||
use_shared_memory (bool): Whether to use shared memory for the buffer.
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.memory: list[Transition] = []
|
||||
self.memory: list[Transition] = torch.multiprocessing.Manager().list() if use_shared_memory else []
|
||||
self.position = 0
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
# (you can handle this differently if needed)
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
# Convert state_keys to a list for pickling
|
||||
self.state_keys = list(state_keys) if state_keys is not None else []
|
||||
|
||||
if image_augmentation_function is None:
|
||||
self.image_augmentation_function = functools.partial(random_shift, pad=4)
|
||||
self.use_drq = use_drq
|
||||
@@ -187,7 +190,7 @@ class ReplayBuffer:
|
||||
# }
|
||||
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
self.memory.append({}) # Need to append something first for Manager().list()
|
||||
|
||||
# Create and store the Transition
|
||||
self.memory[self.position] = Transition(
|
||||
@@ -210,6 +213,7 @@ class ReplayBuffer:
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
action_delta: Optional[float] = None,
|
||||
use_shared_memory: bool = False,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
@@ -233,7 +237,7 @@ class ReplayBuffer:
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
|
||||
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys, use_shared_memory=use_shared_memory)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
@@ -345,7 +349,19 @@ class ReplayBuffer:
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
batch_size = min(batch_size, len(self.memory))
|
||||
list_of_transitions = random.sample(self.memory, batch_size)
|
||||
# Sample from a plain-list copy of the buffer, whether it is a shared Manager().list() or a regular list
|
||||
|
||||
list_of_transitions = random.sample(list(self.memory), batch_size)
|
||||
# if isinstance(self.memory, multiprocessing.managers.ListProxy):
|
||||
# # For shared memory list, we need to be careful about thread safety
|
||||
# with torch.multiprocessing.Lock():
|
||||
# # Get indices first to minimize lock time
|
||||
# indices = torch.randint(len(self.memory), size=(batch_size,)).tolist()
|
||||
# # Convert to list to avoid multiple proxy accesses
|
||||
# list_of_transitions = [self.memory[i] for i in indices]
|
||||
# else:
|
||||
# # For regular list, use faster random.sample
|
||||
# list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
|
||||
Reference in New Issue
Block a user