mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
2945bbb221
commit
7c05755823
@@ -142,12 +142,8 @@ def run_server(
|
||||
)
|
||||
)
|
||||
|
||||
@app.route(
|
||||
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
|
||||
)
|
||||
def show_episode(
|
||||
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
|
||||
):
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
|
||||
repo_id = f"{dataset_namespace}/{dataset_name}"
|
||||
try:
|
||||
if dataset is None:
|
||||
@@ -158,9 +154,7 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
str(dataset.meta._version)
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.codebase_version
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -168,9 +162,7 @@ def run_server(
|
||||
if major_version < 2:
|
||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||
|
||||
episode_data_csv_str, columns, ignored_columns = get_episode_data(
|
||||
dataset, episode_id
|
||||
)
|
||||
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
|
||||
dataset_info = {
|
||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||
"num_samples": dataset.num_frames
|
||||
@@ -183,8 +175,7 @@ def run_server(
|
||||
}
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
video_paths = [
|
||||
dataset.meta.get_video_file_path(episode_id, key)
|
||||
for key in dataset.meta.video_keys
|
||||
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
|
||||
]
|
||||
videos_info = [
|
||||
{
|
||||
@@ -195,9 +186,7 @@ def run_server(
|
||||
]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
else:
|
||||
video_keys = [
|
||||
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
|
||||
]
|
||||
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
|
||||
videos_info = [
|
||||
{
|
||||
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
@@ -217,24 +206,16 @@ def run_server(
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
tasks_jsonl = [
|
||||
json.loads(line) for line in response.text.splitlines() if line.strip()
|
||||
]
|
||||
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
|
||||
|
||||
filtered_tasks_jsonl = [
|
||||
row for row in tasks_jsonl if row["episode_index"] == episode_id
|
||||
]
|
||||
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
|
||||
tasks = filtered_tasks_jsonl[0]["tasks"]
|
||||
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(
|
||||
range(
|
||||
dataset.num_episodes
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_episodes
|
||||
)
|
||||
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
|
||||
)
|
||||
|
||||
return render_template(
|
||||
@@ -261,11 +242,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
|
||||
selected_columns = [
|
||||
col
|
||||
for col, ft in dataset.features.items()
|
||||
if ft["dtype"] in ["float32", "int32"]
|
||||
]
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
|
||||
selected_columns.remove("timestamp")
|
||||
|
||||
ignored_columns = []
|
||||
@@ -286,10 +263,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
else dataset.features[column_name].shape[0]
|
||||
)
|
||||
|
||||
if (
|
||||
"names" in dataset.features[column_name]
|
||||
and dataset.features[column_name]["names"]
|
||||
):
|
||||
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
|
||||
column_names = dataset.features[column_name]["names"]
|
||||
while not isinstance(column_names, list):
|
||||
column_names = list(column_names.values())[0]
|
||||
@@ -312,12 +286,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
else:
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
url = (
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
+ dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
df = pd.read_parquet(url)
|
||||
data = df[selected_columns] # Select specific columns
|
||||
@@ -350,9 +321,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(
|
||||
dataset: LeRobotDataset, ep_index: int
|
||||
) -> list[str]:
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.features:
|
||||
return None
|
||||
@@ -363,9 +332,7 @@ def get_episode_language_instruction(
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
# with the tf.tensor appearing in the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
|
||||
"', shape=(), dtype=string)"
|
||||
)
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
|
||||
|
||||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
@@ -401,9 +368,7 @@ def visualize_dataset_html(
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(
|
||||
f"Output directory already exists. Loading from it: '{output_dir}'"
|
||||
)
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user