[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by Michel Aractingi
parent bb69cb3c8c
commit 85fe8a3f4e
79 changed files with 2800 additions and 794 deletions

View File

@@ -81,7 +81,11 @@ def run_server(
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app = Flask(
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
@@ -138,8 +142,12 @@ def run_server(
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
@@ -150,7 +158,9 @@ def run_server(
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
str(dataset.meta._version)
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@@ -158,7 +168,9 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
episode_data_csv_str, columns, ignored_columns = get_episode_data(
dataset, episode_id
)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -171,18 +183,23 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
]
videos_info = [
{
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
"url": url_for(
"static", filename=str(video_path).replace("\\", "/")
),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@@ -197,20 +214,29 @@ def run_server(
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
timeout=5,
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
)
return render_template(
@@ -237,7 +263,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns = [
col
for col, ft in dataset.features.items()
if ft["dtype"] in ["float32", "int32"]
]
selected_columns.remove("timestamp")
ignored_columns = []
@@ -258,7 +288,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@@ -281,8 +314,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
@@ -315,7 +352,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@@ -326,12 +365,15 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
timeout=5,
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
@@ -361,7 +403,9 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
output_dir.mkdir(parents=True, exist_ok=True)