[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by Michel Aractingi
parent bb69cb3c8c
commit 85fe8a3f4e
79 changed files with 2800 additions and 794 deletions

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
@@ -737,7 +736,6 @@ def concatenate_batch_transitions(
if __name__ == "__main__":
import numpy as np
from tempfile import TemporaryDirectory
# ===== Test 1: Create and use a synthetic ReplayBuffer =====
@@ -1139,7 +1137,7 @@ if __name__ == "__main__":
savings_percent = (std_mem - opt_mem) / std_mem * 100
print(f"\nMemory optimization result:")
print("\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

View File

@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument(
"--repo-id",
type=str,
@@ -247,7 +249,9 @@ if __name__ == "__main__":
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@@ -256,7 +260,7 @@ if __name__ == "__main__":
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
with open(args.crop_params_path) as f:
rois = json.load(f)
# rois = {

View File

@@ -31,7 +31,9 @@ def find_joint_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
@@ -57,7 +59,12 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)

View File

@@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device:str
cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(

View File

@@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action):
action, telop = action
@@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action):
if isinstance(action, tuple):
@@ -137,7 +145,9 @@ def make_maniskill(
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
@@ -149,10 +159,11 @@ def make_maniskill(
if __name__ == "__main__":
import argparse
import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args()
# Initialize config