mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
try fix 3
This commit is contained in:
@@ -80,6 +80,33 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -834,10 +861,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
|
||||
# We MUST import this here to avoid circular dependency
|
||||
# (utils imports lerobot_dataset for backward_compatibility)
|
||||
from lerobot.datasets.utils import hf_transform_to_torch
|
||||
|
||||
features = get_hf_features_from_features(self.features)
|
||||
|
||||
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
|
||||
@@ -1718,30 +1741,3 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
Reference in New Issue
Block a user