Files
lerobot-clone/lerobot/common/utils/utils.py

242 lines
7.4 KiB
Python
Raw Normal View History

2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2024-02-29 23:13:06 +00:00
import logging
import os
2024-03-26 16:13:40 +00:00
import os.path as osp
import platform
2025-03-30 15:41:39 +02:00
import select
2025-03-01 19:19:26 +01:00
import subprocess
2025-03-30 15:41:39 +02:00
import sys
from copy import copy
from datetime import datetime, timezone
2024-03-26 16:13:40 +00:00
from pathlib import Path
import numpy as np
import torch
def none_or_int(value):
    """Convert ``value`` to an int, treating the literal string "None" as ``None``.

    Args:
        value: Anything accepted by ``int()``, or the exact string ``"None"``.

    Returns:
        ``None`` when ``value`` equals ``"None"``, otherwise ``int(value)``.
    """
    return None if value == "None" else int(value)
def inside_slurm():
    """Tell whether the ``SLURM_JOB_ID`` environment variable is set for this process."""
    # TODO(rcadene): return False for interactive mode `--pty bash`
    job_id = os.environ.get("SLURM_JOB_ID")
    return job_id is not None
def auto_select_torch_device() -> torch.device:
    """Pick a torch device automatically, preferring accelerated backends.

    The backends are probed in order: CUDA first, then MPS (Metal), and finally
    the CPU as a fallback.

    Returns:
        ``torch.device("cuda")``, ``torch.device("mps")`` or ``torch.device("cpu")``
        depending on which backend reports itself as available.
    """
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        # The message previously claimed "using cuda" although mps is returned.
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
    return torch.device("cpu")
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Turn a device name into a ``torch.device``, asserting that cuda/mps are usable.

    Args:
        try_device: Requested device; it is converted with ``str()`` before matching.
        log: When True, emit a warning for the "cpu" case and for any name that
            is not exactly "cuda", "mps" or "cpu".

    Returns:
        The corresponding ``torch.device``.

    Raises:
        AssertionError: If "cuda" or "mps" is requested but the backend is not available.
    """
    try_device = str(try_device)

    if try_device == "cuda":
        assert torch.cuda.is_available()
        return torch.device("cuda")

    if try_device == "mps":
        assert torch.backends.mps.is_available()
        return torch.device("mps")

    if try_device == "cpu":
        cpu_device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
        return cpu_device

    # Any other name is handed to torch as-is, without an availability check.
    custom_device = torch.device(try_device)
    if log:
        logging.warning(f"Using custom {try_device} device.")
    return custom_device
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
"""
mps is currently not compatible with float64
"""
if isinstance(device, torch.device):
device = device.type
if device == "mps" and dtype == torch.float64:
return torch.float32
else:
return dtype
def is_torch_device_available(try_device: str) -> bool:
    """Report whether the named torch backend can be used.

    Args:
        try_device: One of "cuda", "mps" or "cpu" (converted with ``str()`` first).

    Returns:
        True for "cpu"; for "cuda" and "mps", whatever torch reports.

    Raises:
        ValueError: If the name is none of the three supported devices.
    """
    try_device = str(try_device)  # Accept non-string inputs by stringifying them

    if try_device == "cpu":
        return True
    if try_device == "cuda":
        return torch.cuda.is_available()
    if try_device == "mps":
        return torch.backends.mps.is_available()

    raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
def is_amp_available(device: str):
    """Tell whether automatic mixed precision is treated as available on ``device``.

    Args:
        device: Device name. It is converted with ``str()`` first, so a
            ``torch.device`` such as ``torch.device("cpu")`` is accepted too,
            matching the other device helpers in this module.

    Returns:
        True for "cuda" and "cpu", False for "mps".

    Raises:
        ValueError: If the device name is not one of "cuda", "cpu" or "mps".
    """
    device = str(device)
    if device in ("cuda", "cpu"):
        return True
    if device == "mps":
        return False
    # The closing quote was missing from this message.
    raise ValueError(f"Unknown device '{device}'.")
2024-02-29 23:13:06 +00:00
def init_logging():
    """Configure the root logger with a single console handler and a compact format.

    Each emitted line looks like ``LEVEL YYYY-mm-dd HH:MM:SS <file:line> message``,
    where ``<file:line>`` is the last 15 characters of ``pathname:lineno``,
    right-aligned. Handlers already attached to the root logger are removed.
    """

    def custom_format(record):
        # Use the record's own creation time so the timestamp reflects when the
        # event was logged rather than when it happens to be formatted.
        dt = datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        # getMessage() merges %-style arguments into the message; using
        # record.msg directly would print the raw template (e.g. "x=%s").
        return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"

    logging.basicConfig(level=logging.INFO)

    # Drop every handler currently on the root logger (iterate over a copy,
    # since removeHandler mutates the list).
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)
def format_big_number(num, precision=0):
    """Format a number with a thousands-based suffix (K, M, B, T, Q).

    The value is divided by 1000 until its magnitude drops below 1000 or the
    largest suffix ("Q") is reached. Values too large even for "Q" are still
    rendered with the "Q" suffix, so the function always returns a string.

    Args:
        num: The number to format.
        precision: Number of digits after the decimal point.

    Returns:
        The formatted string, e.g. ``"1K"`` for 1200 or ``"1.5M"`` for
        1_500_000 with ``precision=1``.
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    # Scale down through every suffix except the last one.
    for suffix in suffixes[:-1]:
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    # Largest suffix: format whatever is left instead of dividing once more and
    # returning a bare float.
    return f"{num:.{precision}f}{suffixes[-1]}"
2024-03-26 16:13:40 +00:00
def _relative_path_between(path1: Path, path2: Path) -> Path:
    """Express ``path1`` relative to ``path2``.

    Both paths are made absolute first. When ``path1`` is not located under
    ``path2``, the result climbs out of ``path2`` with ``..`` segments up to
    their common ancestor and then descends to ``path1``.
    """
    target = path1.absolute()
    base = path2.absolute()
    try:
        return target.relative_to(base)
    except ValueError:  # most likely because path1 is not a subpath of path2
        shared_count = len(Path(osp.commonpath([target, base])).parts)
        climb = [".."] * (len(base.parts) - shared_count)
        descend = list(target.parts[shared_count:])
        return Path("/".join(climb + descend))
2024-04-05 10:59:32 +00:00
def print_cuda_memory_usage():
    """Use this function to locate and debug memory leak.

    Runs the garbage collector, empties the CUDA cache, then prints the
    current and maximum allocated/reserved memory of device 0 in MB.
    """
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()

    bytes_per_mb = 1024**2
    reports = (
        ("Current GPU Memory Allocated: {:.2f} MB", torch.cuda.memory_allocated),
        ("Maximum GPU Memory Allocated: {:.2f} MB", torch.cuda.max_memory_allocated),
        ("Current GPU Memory Reserved: {:.2f} MB", torch.cuda.memory_reserved),
        ("Maximum GPU Memory Reserved: {:.2f} MB", torch.cuda.max_memory_reserved),
    )
    for template, query in reports:
        print(template.format(query(0) / bytes_per_mb))
def capture_timestamp_utc():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc
def say(text, blocking=False):
    """Speak ``text`` aloud using the platform's text-to-speech command.

    Uses ``say`` on macOS, ``spd-say`` on Linux and the .NET speech synthesizer
    through PowerShell on Windows.

    Args:
        text: The sentence to speak.
        blocking: When True, wait for the command to finish (and raise if it
            exits with a non-zero status); otherwise launch it and return
            immediately.

    Raises:
        RuntimeError: If the operating system is not Darwin, Linux or Windows.
        subprocess.CalledProcessError: In blocking mode, if the command fails.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]
    elif system == "Linux":
        cmd = ["spd-say", text]
        if blocking:
            # Without this flag the command is not asked to wait for the speech to end.
            cmd.append("--wait")
    elif system == "Windows":
        # The text is embedded in a single-quoted PowerShell string; a literal
        # apostrophe must be doubled, otherwise it terminates the string and
        # the remainder is parsed as PowerShell code.
        escaped_text = str(text).replace("'", "''")
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{escaped_text}')",
        ]
    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # CREATE_NO_WINDOW only exists on Windows, so it is only looked up there.
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
    """Log ``text`` at INFO level and, if ``play_sounds`` is truthy, also speak it.

    Args:
        text: Message to log (and optionally speak).
        play_sounds: Whether to pass the message on to ``say``.
        blocking: Forwarded to ``say``.
    """
    logging.info(text)
    if not play_sounds:
        return
    say(text, blocking)
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Return ``image_shape`` reordered so its smallest dimension comes first.

    The smallest of the first three entries is taken to be the channel axis:
    if it is the third entry, the shape is rearranged from (h, w, c) to
    (c, h, w); if it is the first entry, a copy of the shape is returned.

    Raises:
        ValueError: If neither the first nor the third entry is strictly
            smaller than the other two.
    """
    first, second, third = image_shape[0], image_shape[1], image_shape[2]

    if third < first and third < second:  # (h, w, c) -> (c, h, w)
        return (third, first, second)
    if first < second and first < third:  # already (c, h, w)
        return copy(image_shape)
    raise ValueError(image_shape)
def has_method(cls: object, method_name: str) -> bool:
    """Return True when ``cls`` has a callable attribute named ``method_name``."""
    candidate = getattr(cls, method_name, None)
    return callable(candidate)
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """
    Tell whether ``np.dtype`` accepts the given string.

    Returns False when ``np.dtype(dtype_str)`` raises ``TypeError`` and True
    when the conversion succeeds.
    """
    try:
        np.dtype(dtype_str)
    except TypeError:
        # numpy rejected the string as a dtype specification
        return False
    return True
2025-03-30 15:41:39 +02:00
2025-04-02 22:27:49 +02:00
def enter_pressed() -> bool:
    """Check, without blocking, whether an empty line is waiting on stdin.

    ``select`` is polled with a zero timeout. If stdin is readable, one line is
    consumed and compared (after stripping whitespace) to the empty string.

    Returns:
        True if a line was available and it was blank, False otherwise.
    """
    readable, _, _ = select.select([sys.stdin], [], [], 0)
    if not readable:
        # Return a real bool here; `[] and ...` used to leak the empty list
        # despite the `-> bool` annotation.
        return False
    return sys.stdin.readline().strip() == ""
2025-04-02 22:27:49 +02:00
def move_cursor_up(lines):
    """Print the ANSI escape sequence that moves the cursor up ``lines`` lines.

    Nothing else is written: the trailing newline is suppressed.
    """
    print(f"\033[{lines}A", end="")
    return None