remove beta annealing

This commit is contained in:
Pepijn
2025-03-14 13:22:22 +01:00
parent 17d12db7c4
commit 4e9b4dd380
2 changed files with 4 additions and 16 deletions

View File

@@ -191,11 +191,9 @@ def train(cfg: TrainPipelineConfig):
sampler = PrioritizedSampler(
data_len=data_len,
alpha=0.6,
beta=0.4, # For important sampling
eps=1e-6,
num_samples_per_epoch=data_len,
beta_start=0.4,
beta_end=1.0,
total_steps=cfg.steps,
)
dataloader = torch.utils.data.DataLoader(
@@ -234,7 +232,6 @@ def train(cfg: TrainPipelineConfig):
batch[key] = batch[key].to(device, non_blocking=True)
if "indices" in batch:
sampler.update_beta(step)
is_weights = sampler.compute_is_weights(batch["indices"].cpu().tolist())
batch["is_weights"] = is_weights