Replace threading with multiprocessing

This commit is contained in:
Michel Aractingi
2025-02-20 17:21:55 +00:00
parent ff47c0b0d3
commit a9e912a05c
5 changed files with 870 additions and 36 deletions

View File

@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Sequence, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
import multiprocessing
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -135,10 +136,11 @@ class ReplayBuffer:
self,
capacity: int,
device: str = "cuda:0",
state_keys: Optional[Sequence[str]] = None,
state_keys: Optional[list[str]] = None,
image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True,
storage_device: str = "cpu",
use_shared_memory: bool = False,
):
"""
Args:
@@ -150,16 +152,17 @@ class ReplayBuffer:
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
Using "cpu" can help save GPU memory.
use_shared_memory (bool): Whether to use shared memory for the buffer.
"""
self.capacity = capacity
self.device = device
self.storage_device = storage_device
self.memory: list[Transition] = []
self.memory: list[Transition] = torch.multiprocessing.Manager().list() if use_shared_memory else []
self.position = 0
# If no state_keys provided, default to an empty list
# (you can handle this differently if needed)
self.state_keys = state_keys if state_keys is not None else []
# Convert state_keys to a list for pickling
self.state_keys = list(state_keys) if state_keys is not None else []
if image_augmentation_function is None:
self.image_augmentation_function = functools.partial(random_shift, pad=4)
self.use_drq = use_drq
@@ -187,7 +190,7 @@ class ReplayBuffer:
# }
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory.append({}) # Need to append something first for Manager().list()
# Create and store the Transition
self.memory[self.position] = Transition(
@@ -210,6 +213,7 @@ class ReplayBuffer:
capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None,
action_delta: Optional[float] = None,
use_shared_memory: bool = False,
) -> "ReplayBuffer":
"""
Convert a LeRobotDataset into a ReplayBuffer.
@@ -233,7 +237,7 @@ class ReplayBuffer:
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys, use_shared_memory=use_shared_memory)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
@@ -345,7 +349,19 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
batch_size = min(batch_size, len(self.memory))
list_of_transitions = random.sample(self.memory, batch_size)
# Different sampling approach for shared memory list vs regular list
list_of_transitions = random.sample(list(self.memory), batch_size)
# if isinstance(self.memory, multiprocessing.managers.ListProxy):
# # For shared memory list, we need to be careful about thread safety
# with torch.multiprocessing.Lock():
# # Get indices first to minimize lock time
# indices = torch.randint(len(self.memory), size=(batch_size,)).tolist()
# # Convert to list to avoid multiple proxy accesses
# list_of_transitions = [self.memory[i] for i in indices]
# else:
# # For regular list, use faster random.sample
# list_of_transitions = random.sample(self.memory, batch_size)
# -- Build batched states --
batch_state = {}