fix: changes to compute stats and modeling

This commit is contained in:
danaaubakirova
2025-07-11 15:50:22 +02:00
parent 008b592545
commit 7848b15bfb
6 changed files with 86 additions and 10 deletions

View File

@@ -175,11 +175,18 @@ def train(cfg: TrainPipelineConfig):
else:
shuffle = True
sampler = None
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
keys_to_max_dim = {
"action": (32,),
"observation.state": (32,),
"observation.image": (3, 1080, 1920),
"observation.image2": (3, 1080, 1920),
}
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=collate_fn,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,