From e3c3c165aae220a291ff032dcdc05e5d22f96928 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 11 Mar 2025 12:04:40 +0100 Subject: [PATCH] Add inital weighted sampling as prioritzed experience raplay with sum tree --- lerobot/common/datasets/lerobot_dataset.py | 3 + lerobot/common/datasets/sampler.py | 135 +++++++++++++++++++- lerobot/common/policies/act/modeling_act.py | 22 +++- lerobot/scripts/train.py | 20 ++- 4 files changed, 169 insertions(+), 11 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5414c76df..6370555ad 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -749,6 +749,9 @@ class LeRobotDataset(torch.utils.data.Dataset): task_idx = item["task_index"].item() item["task"] = self.meta.tasks[task_idx] + # Add global index of frame (indices) + item["indices"] = torch.tensor(idx) + return item def __repr__(self): diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 2f6c15c15..941a5561d 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterator, Union +import random +from typing import Iterator, List, Optional, Union import torch +from torch.utils.data import Sampler class EpisodeAwareSampler: @@ -59,3 +61,134 @@ class EpisodeAwareSampler: def __len__(self) -> int: return len(self.indices) + + +class SumTree: + """ + A classic sum-tree data structure for storing priorities. + Each leaf stores a sample's priority, and internal nodes store sums of children. + """ + + def __init__(self, capacity: int): + """ + Args: + capacity: Maximum number of elements. The tree size is the next power of 2 for efficiency. + """ + self.capacity = capacity + self.size = 1 + while self.size < capacity: + self.size *= 2 # Ensure power-of-two size for efficient updates + + self.tree = [0.0] * (2 * self.size) # Tree structure + + def initialize_tree(self, priorities: List[float]): + """ + Efficiently initializes the sum tree in O(n). + """ + # Set leaf values + for i, priority in enumerate(priorities): + self.tree[i + self.size] = priority + + # Compute internal node values + for i in range(self.size - 1, 0, -1): + self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1] + + def update(self, idx: int, priority: float): + """ + Update the priority at leaf index `idx` and propagate changes upwards. + """ + tree_idx = idx + self.size + self.tree[tree_idx] = priority # Set new priority + + # Propagate up, explicitly summing children + tree_idx //= 2 + while tree_idx >= 1: + self.tree[tree_idx] = self.tree[2 * tree_idx] + self.tree[2 * tree_idx + 1] + tree_idx //= 2 + + def total_priority(self) -> float: + """Returns the sum of all priorities (stored at root).""" + return self.tree[1] + + def sample(self, value: float) -> int: + """ + Samples an index where the prefix sum up to that leaf is >= `value`. + """ + value = min(max(value, 0), self.total_priority()) # Clamp value + idx = 1 + while idx < self.size: + left = 2 * idx + if self.tree[left] >= value: + idx = left + else: + value -= self.tree[left] + idx = left + 1 + return idx - self.size # Convert tree index to data index + + +class PrioritizedSampler(Sampler[int]): + """ + PyTorch Sampler that draws samples in proportion to their priority using a SumTree. + """ + + def __init__( + self, + data_len: int, + alpha: float = 0.6, + beta: float = 0.1, + eps: float = 1e-6, + replacement: bool = True, + num_samples_per_epoch: Optional[int] = None, + ): + """ + Args: + data_len: Total number of samples in the dataset. + alpha: Exponent for priority scaling. Default is 0.6. + beta: Smoothing offset to avoid excluding low-priority samples. + eps: Small constant to avoid zero priorities. + replacement: Whether to sample with replacement. + num_samples_per_epoch: Number of samples per epoch (default is data_len). + """ + self.data_len = data_len + self.alpha = alpha + self.beta = beta + self.eps = eps + self.replacement = replacement + self.num_samples_per_epoch = num_samples_per_epoch or data_len + + # Initialize difficulties and sum-tree + self.difficulties = [1.0] * data_len # Default difficulty = 1.0 + initial_priorities = [(1.0 + eps) ** alpha + beta] * data_len # Compute initial priorities + self.sumtree = SumTree(data_len) + self.sumtree.initialize_tree(initial_priorities) # Bulk load in O(n) + + def update_priorities(self, indices: List[int], difficulties: List[float]): + """ + Updates the priorities in the sum-tree. + """ + for idx, diff in zip(indices, difficulties, strict=False): + self.difficulties[idx] = diff + new_priority = (diff + self.eps) ** self.alpha + self.beta + self.sumtree.update(idx, new_priority) + + def __iter__(self) -> Iterator[int]: + """ + Samples indices based on their priority weights. + """ + total_p = self.sumtree.total_priority() + sampled_indices = set() if not self.replacement else None + + for _ in range(self.num_samples_per_epoch): + r = random.random() * total_p + idx = self.sumtree.sample(r) + + if not self.replacement: + while idx in sampled_indices: + r = random.random() * total_p + idx = self.sumtree.sample(r) + sampled_indices.add(idx) + + yield idx + + def __len__(self) -> int: + return self.num_samples_per_epoch diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index f2b16a1eb..99270c29d 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -155,11 +155,15 @@ class ACTPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + elementwise_l1 = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch[ + "action_is_pad" + ].unsqueeze(-1) + + l1_loss = elementwise_l1.mean() + + # mean over time+action_dim => per-sample array of shape (B,) + l1_per_sample = elementwise_l1.mean(dim=(1, 2)) - loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total @@ -168,9 +172,17 @@ class ACTPolicy(PreTrainedPolicy): mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) - loss_dict["kld_loss"] = mean_kld.item() + loss_dict = { + "l1_loss": l1_loss.item(), + "kld_loss": mean_kld.item(), + "per_sample_l1": l1_per_sample, # shape (B,) + } loss = l1_loss + mean_kld * self.config.kl_weight else: + loss_dict = { + "l1_loss": l1_loss.item(), + "per_sample_l1": l1_per_sample, # shape (B,) + } loss = l1_loss return loss, loss_dict diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index f2b1e29e3..52fed33de 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -25,7 +25,7 @@ from torch.amp import GradScaler from torch.optim import Optimizer from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.sampler import EpisodeAwareSampler +from lerobot.common.datasets.sampler import PrioritizedSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.optim.factory import make_optimizer_and_scheduler @@ -126,6 +126,7 @@ def train(cfg: TrainPipelineConfig): logging.info("Creating dataset") dataset = make_dataset(cfg) + data_len = len(dataset) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -165,10 +166,13 @@ def train(cfg: TrainPipelineConfig): # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False - sampler = EpisodeAwareSampler( - dataset.episode_data_index, - drop_n_last_frames=cfg.policy.drop_n_last_frames, - shuffle=True, + sampler = PrioritizedSampler( + data_len=data_len, + alpha=0.6, + beta=0.1, + eps=1e-6, + replacement=True, + num_samples_per_epoch=data_len, ) else: shuffle = True @@ -220,6 +224,12 @@ def train(cfg: TrainPipelineConfig): use_amp=cfg.policy.use_amp, ) + # If we have 'indices' and 'per_sample_l1' then update sampler + if "indices" in batch and "per_sample_l1" in output_dict: + indices = batch["indices"].detach().cpu().tolist() # shape (B,) + difficulties = output_dict["per_sample_l1"].detach().cpu().tolist() # shape (B,) + sampler.update_priorities(indices, difficulties) + # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. step += 1