mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
remove beta annealing
This commit is contained in:
@@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
sampler = PrioritizedSampler(
|
||||
data_len=data_len,
|
||||
alpha=0.6,
|
||||
beta=0.4, # For important sampling
|
||||
eps=1e-6,
|
||||
num_samples_per_epoch=data_len,
|
||||
beta_start=0.4,
|
||||
beta_end=1.0,
|
||||
total_steps=cfg.steps,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
if "indices" in batch:
|
||||
sampler.update_beta(step)
|
||||
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
|
||||
batch["is_weights"] = is_weights
|
||||
|
||||
|
||||
Reference in New Issue
Block a user