mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
try fix 3
This commit is contained in:
@@ -80,6 +80,33 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
|
|||||||
CODEBASE_VERSION = "v3.0"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
|
||||||
|
|
||||||
|
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||||
|
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||||
|
|
||||||
|
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||||
|
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||||
|
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||||
|
types are converted to torch.tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items_dict (dict): A dictionary representing a batch of data from a
|
||||||
|
Hugging Face dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The batch with items converted to torch tensors.
|
||||||
|
"""
|
||||||
|
for key in items_dict:
|
||||||
|
first_item = items_dict[key][0]
|
||||||
|
if isinstance(first_item, PILImage.Image):
|
||||||
|
to_tensor = transforms.ToTensor()
|
||||||
|
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||||
|
elif first_item is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||||
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -834,10 +861,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def load_hf_dataset(self) -> datasets.Dataset:
|
def load_hf_dataset(self) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
|
|
||||||
# We MUST import this here to avoid circular dependency
|
|
||||||
# (utils imports lerobot_dataset for backward_compatibility)
|
|
||||||
from lerobot.datasets.utils import hf_transform_to_torch
|
|
||||||
|
|
||||||
features = get_hf_features_from_features(self.features)
|
features = get_hf_features_from_features(self.features)
|
||||||
|
|
||||||
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
|
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
|
||||||
@@ -1718,30 +1741,3 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
f" Transformations: {self.image_transforms},\n"
|
f" Transformations: {self.image_transforms},\n"
|
||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
|
||||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
|
||||||
|
|
||||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
|
||||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
|
||||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
|
||||||
types are converted to torch.tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items_dict (dict): A dictionary representing a batch of data from a
|
|
||||||
Hugging Face dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: The batch with items converted to torch tensors.
|
|
||||||
"""
|
|
||||||
for key in items_dict:
|
|
||||||
first_item = items_dict[key][0]
|
|
||||||
if isinstance(first_item, PILImage.Image):
|
|
||||||
to_tensor = transforms.ToTensor()
|
|
||||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
|
||||||
elif first_item is None:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
|
||||||
return items_dict
|
|
||||||
|
|||||||
Reference in New Issue
Block a user