mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
fix: changes to compute stats and modeling
This commit is contained in:
@@ -175,11 +175,18 @@ def train(cfg: TrainPipelineConfig):
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
|
||||
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
|
||||
keys_to_max_dim = {
|
||||
"action": (32,),
|
||||
"observation.state": (32,),
|
||||
"observation.image": (3, 1080, 1920),
|
||||
"observation.image2": (3, 1080, 1920),
|
||||
}
|
||||
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
|
||||
Reference in New Issue
Block a user