diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 48ea2f3ef..8c6d5dc4a 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -15,6 +15,7 @@ # limitations under the License. import functools +import threading from collections.abc import Callable, Sequence from contextlib import suppress from typing import TypedDict @@ -115,6 +116,7 @@ class ReplayBuffer: self.size = 0 self.initialized = False self.optimize_memory = optimize_memory + self._lock = threading.Lock() # Track episode boundaries for memory optimization self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) @@ -198,68 +200,75 @@ class ReplayBuffer: complementary_info: dict[str, torch.Tensor] | None = None, ): """Saves a transition, ensuring tensors are stored on the designated storage device.""" - # Initialize storage if this is the first transition - if not self.initialized: - self._initialize_storage(state=state, action=action, complementary_info=complementary_info) + with self._lock: + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) - # Store the transition in pre-allocated tensors - for key in self.states: - self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) - if not self.optimize_memory: - # Only store next_states if not optimizing memory - self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) + if not self.optimize_memory: + # Only store next_states if not optimizing memory + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) - self.actions[self.position].copy_(action.squeeze(dim=0)) - self.rewards[self.position] = reward - self.dones[self.position] = done - self.truncateds[self.position] = truncated + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated - # Handle complementary_info if provided and storage is initialized - if complementary_info is not None and self.has_complementary_info: - # Store the complementary_info - for key in self.complementary_info_keys: - if key in complementary_info: - value = complementary_info[key] - if isinstance(value, torch.Tensor): - self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) - elif isinstance(value, (int | float)): - self.complementary_info[key][self.position] = value + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int | float)): + self.complementary_info[key][self.position] = value - self.position = (self.position + 1) % self.capacity - self.size = min(self.size + 1, self.capacity) + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) def sample(self, batch_size: int) -> BatchTransition: """Sample a random batch of transitions and collate them into batched tensors.""" if not self.initialized: raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") - batch_size = min(batch_size, self.size) - high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size + with self._lock: + batch_size = min(batch_size, self.size) + high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size - # Random indices for sampling - create on the same device as storage - idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) + idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) - # Identify image keys that need augmentation - image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] + image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] - # Create batched state and next_state - batch_state = {} - batch_next_state = {} + batch_state = {} + batch_next_state = {} - # First pass: load all state tensors to target device - for key in self.states: - batch_state[key] = self.states[key][idx].to(self.device) + for key in self.states: + batch_state[key] = self.states[key][idx].to(self.device) - if not self.optimize_memory: - # Standard approach - load next_states directly - batch_next_state[key] = self.next_states[key][idx].to(self.device) - else: - # Memory-optimized approach - get next_state from the next index - next_idx = (idx + 1) % self.capacity - batch_next_state[key] = self.states[key][next_idx].to(self.device) + if not self.optimize_memory: + batch_next_state[key] = self.next_states[key][idx].to(self.device) + else: + next_idx = (idx + 1) % self.capacity + batch_next_state[key] = self.states[key][next_idx].to(self.device) + + # Sample other tensors + batch_actions = self.actions[idx].to(self.device) + batch_rewards = self.rewards[idx].to(self.device) + batch_dones = self.dones[idx].to(self.device).float() + batch_truncateds = self.truncateds[idx].to(self.device).float() + + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) - # Apply image augmentation in a batched way if needed if self.use_drq and image_keys: # Concatenate all images from state and next_state all_images = [] @@ -280,19 +289,6 @@ class ReplayBuffer: # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] - # Sample other tensors - batch_actions = self.actions[idx].to(self.device) - batch_rewards = self.rewards[idx].to(self.device) - batch_dones = self.dones[idx].to(self.device).float() - batch_truncateds = self.truncateds[idx].to(self.device).float() - - # Sample complementary_info if available - batch_complementary_info = None - if self.has_complementary_info: - batch_complementary_info = {} - for key in self.complementary_info_keys: - batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) - return BatchTransition( state=batch_state, action=batch_actions,